sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/managers/tp_worker_overlap_thread.py (new file)

```diff
@@ -0,0 +1,209 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A tensor parallel worker."""
+
+import logging
+import threading
+import time
+from queue import Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
+class TpModelWorkerClient:
+    """A tensor parallel model worker."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        nccl_port: int,
+    ):
+        # Load the model
+        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.max_running_requests = self.worker.max_running_requests
+        self.device = self.worker.device
+
+        # Init future mappings
+        self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+
+        # Launch threads
+        self.input_queue = Queue()
+        self.output_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+        self.copy_queue = Queue()
+        self.copy_thread = threading.Thread(
+            target=self.copy_thread_func,
+        )
+        self.copy_thread.start()
+
+    def get_worker_info(self):
+        return self.worker.get_worker_info()
+
+    def get_pad_input_ids_func(self):
+        return self.worker.get_pad_input_ids_func()
+
+    def get_tp_cpu_group(self):
+        return self.worker.get_tp_cpu_group()
+
+    def get_memory_pool(self):
+        return (
+            self.worker.model_runner.req_to_token_pool,
+            self.worker.model_runner.token_to_kv_pool,
+        )
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            self.has_inflight_batch = False
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
+            self.has_inflight_batch = True
+            self.launch_event = threading.Event()
+
+            # Resolve future tokens in the input
+            input_ids = model_worker_batch.input_ids
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+
+            # Run forward
+            logits_output, next_token_ids = self.worker.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids
+
+            # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            copy_event = torch.cuda.Event(blocking=True)
+            copy_event.record()
+
+            self.launch_event.set()
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
+
+    def copy_thread_func(self):
+        while True:
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
+            while not copy_event.query():
+                time.sleep(1e-5)
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
+
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        if self.has_inflight_batch:
+            # Wait until the batch is launched
+            self.launch_event.wait()
+        return logits_output, next_token_ids
+
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+
+        # Allocate output future objects
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        return None, future_next_token_ids
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+    def update_weights(self, recv_req: UpdateWeightReqInput):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
```
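The heart of this new file is the future-token indirection: `forward_batch_generation` returns negative placeholder ids immediately, the forward thread later writes the real sampled ids into `future_token_ids_map`, and `resolve_future_token_ids` patches the placeholders in the next batch's input. A minimal standalone sketch of that round-trip (plain CPU tensors, the `@torch.compile` decorator dropped; illustrative only, not code from the package):

```python
import torch

def resolve_future_token_ids(input_ids, future_token_ids_map):
    # Negative ids are placeholders: -k refers to slot k of the map.
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )

future_map = torch.zeros(16, dtype=torch.int32)

# Step 1: the scheduler receives placeholder ids -1, -2 before sampling finishes.
next_batch_input = torch.tensor([-1, -2], dtype=torch.int32)

# Step 2: the forward thread finishes sampling and fills slots 1..2 of the map.
future_map[1:3] = torch.tensor([101, 202], dtype=torch.int32)

# Step 3: the next forward pass resolves placeholders to the real token ids.
resolve_future_token_ids(next_batch_input, future_map)
assert next_batch_input.tolist() == [101, 202]
```

This is what lets the scheduler prepare the next batch while the GPU is still sampling the current one, which appears to be the point of the overlap thread.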
sglang/srt/mem_cache/memory_pool.py

```diff
@@ -13,27 +13,46 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""
+"""
+Memory pool.
+
+SGLang has two levels of memory pool.
+ReqToTokenPool maps a request to its token locations.
+BaseTokenToKVPool maps a token location to its KV cache data.
+"""
 
 import logging
 from typing import List, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)
 
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str):
+    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
+        self.write_records = []
+        self.use_records = use_records
+
+        if self.use_records:
+            self.write = self.write_with_records
+        else:
+            self.write = self.write_without_records
+
+    def write(self, indices, values):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
 
     def available_size(self):
         return len(self.free_slots)
@@ -55,10 +74,27 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
+        self.write_records = []
+
+    def write_without_records(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def write_with_records(self, indices, values):
+        self.req_to_token[indices] = values
+        self.write_records.append((indices, values))
+
+    def get_write_records(self):
+        ret = self.write_records
+        self.write_records = []
+        return ret
+
+    def apply_write_records(self, write_records: List[Tuple]):
+        for indices, values in write_records:
+            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
-    """A memory pool that maps a token to its kv cache data."""
+    """A memory pool that maps a token location to its kv cache data."""
 
     def __init__(
         self,
@@ -68,12 +104,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
+        self.device = device
 
         self.free_slots = None
         self.is_not_in_free_group = True
@@ -124,7 +160,7 @@ class BaseTokenToKVPool:
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
@@ -179,14 +215,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
        self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
@@ -196,6 +232,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
             self.v_buffer[layer_id][loc] = cache_v
 
 
+# This compiled version is slower in the unit test
+# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+
+
 class MLATokenToKVPool(BaseTokenToKVPool):
 
     def __init__(
@@ -235,11 +279,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -294,13 +339,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         cache_label: torch.Tensor,
     ):
         # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
```
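The `ReqToTokenPool` change above binds `write` to one of two implementations at construction time and, when records are enabled, logs every write so it can be replayed onto another copy of the table (presumably so the overlap worker can keep two pools consistent; that reading is an inference from the API, not stated in the diff). A self-contained sketch of the pattern with hypothetical names (`TinyPool` is not part of the package):

```python
import torch

class TinyPool:
    """Minimal sketch of the write-records pattern from ReqToTokenPool."""

    def __init__(self, size, width, use_records):
        self.table = torch.zeros((size, width), dtype=torch.int32)
        self.write_records = []
        # Bind `write` to one concrete implementation up front,
        # avoiding a per-call `if use_records:` branch on the hot path.
        self.write = self.write_with_records if use_records else self.write_without_records

    def write_without_records(self, indices, values):
        self.table[indices] = values

    def write_with_records(self, indices, values):
        self.table[indices] = values
        self.write_records.append((indices, values))

    def get_write_records(self):
        ret, self.write_records = self.write_records, []
        return ret

    def apply_write_records(self, records):
        for indices, values in records:
            self.table[indices] = values

src = TinyPool(4, 8, use_records=True)
src.write((0, slice(0, 3)), torch.tensor([7, 8, 9], dtype=torch.int32))

# Replay the logged writes onto a second copy of the table.
dst = TinyPool(4, 8, use_records=False)
dst.apply_write_records(src.get_write_records())
assert torch.equal(src.table, dst.table)
```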
sglang/srt/mem_cache/radix_cache.py

```diff
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )
 
         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
```
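The rewritten `RadixCache` call packs the row index and a `slice` object into a tuple, which indexes `req_to_token` exactly as the old bracket syntax did; `pool.write()` simply forwards that tuple. A quick equivalence check on stand-in tensors (the names below are illustrative, not from the package):

```python
import torch

req_to_token = torch.zeros((2, 10), dtype=torch.int32)
new_indices = torch.arange(1, 7, dtype=torch.int32)
prefix_len = 2  # pretend len(req.prefix_indices) == 2

# Old style: direct slice assignment on the table.
a = req_to_token.clone()
a[0, prefix_len : len(new_indices)] = new_indices[prefix_len:]

# New style: the same indices packed as a tuple, as passed to pool.write().
b = req_to_token.clone()
b[(0, slice(prefix_len, len(new_indices)))] = new_indices[prefix_len:]

assert torch.equal(a, b)
```

Routing the assignment through `write()` is what lets the pool optionally record it, per the memory_pool.py change above.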
         @@ -92,6 +92,11 @@ def set_torch_compile_config(): 
     | 
|
| 
       92 
92 
     | 
    
         
             
                torch._dynamo.config.accumulated_cache_size_limit = 1024
         
     | 
| 
       93 
93 
     | 
    
         | 
| 
       94 
94 
     | 
    
         | 
| 
      
 95 
     | 
    
         
            +
            @torch.compile(dynamic=True)
         
     | 
| 
      
 96 
     | 
    
         
            +
            def clamp_position(seq_lens):
         
     | 
| 
      
 97 
     | 
    
         
            +
                return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
         
     | 
| 
      
 98 
     | 
    
         
            +
             
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
       95 
100 
     | 
    
         
             
            class CudaGraphRunner:
         
     | 
| 
       96 
101 
     | 
    
         
             
                """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
         
     | 
| 
       97 
102 
     | 
    
         | 
| 
         @@ -105,13 +110,13 @@ class CudaGraphRunner: 
     | 
|
| 
       105 
110 
     | 
    
         
             
                    self.graph_memory_pool = None
         
     | 
| 
       106 
111 
     | 
    
         
             
                    self.use_torch_compile = model_runner.server_args.enable_torch_compile
         
     | 
| 
       107 
112 
     | 
    
         
             
                    self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         
     | 
| 
      
 113 
     | 
    
         
            +
                    self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         
     | 
| 
       108 
114 
     | 
    
         | 
| 
       109 
115 
     | 
    
         
             
                    # Batch sizes to capture
         
     | 
| 
       110 
116 
     | 
    
         
             
                    if self.model_runner.server_args.disable_cuda_graph_padding:
         
     | 
| 
       111 
117 
     | 
    
         
             
                        self.capture_bs = list(range(1, 32)) + [64, 128]
         
     | 
| 
       112 
118 
     | 
    
         
             
                    else:
         
     | 
| 
       113 
     | 
    
         
            -
                        self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         
     | 
| 
       114 
     | 
    
         
            -
             
     | 
| 
      
 119 
     | 
    
         
            +
                        self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
         
     | 
| 
       115 
120 
     | 
    
         
             
                    self.capture_bs = [
         
     | 
| 
       116 
121 
     | 
    
         
             
                        bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         
     | 
| 
       117 
122 
     | 
    
         
             
                    ]
         
     | 
| 
         @@ -128,10 +133,14 @@ class CudaGraphRunner: 
     | 
|
| 
       128 
133 
     | 
    
         
             
                    # Attention backend
         
     | 
| 
       129 
134 
     | 
    
         
             
                    self.max_bs = max(self.capture_bs)
         
     | 
| 
       130 
135 
     | 
    
         
             
                    self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
       131 
137 
     | 
    
         
             
                    self.seq_len_fill_value = (
         
     | 
| 
       132 
138 
     | 
    
         
             
                        self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         
     | 
| 
       133 
139 
     | 
    
         
             
                    )
         
     | 
| 
       134 
140 
     | 
    
         | 
| 
      
 141 
     | 
    
         
            +
                    # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         
     | 
| 
      
 142 
     | 
    
         
            +
                    self.encoder_len_fill_value = 0
         
     | 
| 
      
 143 
     | 
    
         
            +
             
     | 
| 
       135 
144 
     | 
    
         
             
                    if self.use_torch_compile:
         
     | 
| 
       136 
145 
     | 
    
         
             
                        set_torch_compile_config()
         
     | 
| 
       137 
146 
     | 
    
         | 
| 
         @@ -143,10 +152,20 @@ class CudaGraphRunner: 
     | 
|
| 
       143 
152 
     | 
    
         
             
                            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         
     | 
| 
       144 
153 
     | 
    
         
             
                        )
         
     | 
| 
       145 
154 
     | 
    
         
             
                        self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
         
     | 
| 
      
 155 
     | 
    
         
            +
                        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
         
     | 
| 
      
 156 
     | 
    
         
            +
             
     | 
| 
      
 157 
     | 
    
         
            +
                        if self.is_encoder_decoder:
         
     | 
| 
      
 158 
     | 
    
         
            +
                            # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
         
     | 
| 
      
 159 
     | 
    
         
            +
                            self.encoder_lens = torch.full(
         
     | 
| 
      
 160 
     | 
    
         
            +
                                (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
         
     | 
| 
      
 161 
     | 
    
         
            +
                            )
         
     | 
| 
      
 162 
     | 
    
         
            +
                        else:
         
     | 
| 
      
 163 
     | 
    
         
            +
                            self.encoder_lens = None
         
     | 
| 
       146 
164 
     | 
    
         | 
| 
       147 
165 
     | 
    
         
             
                    # Capture
         
     | 
| 
       148 
166 
     | 
    
         
             
                    try:
         
     | 
| 
       149 
     | 
    
         
            -
                        self. 
     | 
| 
      
 167 
     | 
    
         
            +
                        with self.model_capture_mode():
         
     | 
| 
      
 168 
     | 
    
         
            +
                            self.capture()
         
     | 
| 
       150 
169 
     | 
    
         
             
                    except RuntimeError as e:
         
     | 
| 
       151 
170 
     | 
    
         
             
                        raise Exception(
         
     | 
| 
       152 
171 
     | 
    
         
             
                            f"Capture cuda graph failed: {e}\n"
         
     | 
| 
@@ -157,11 +176,32 @@ class CudaGraphRunner:
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
-    def can_run(self, batch_size: int):
-        if self.disable_padding:
-            return batch_size in self.graphs
-        else:
-            return batch_size <= self.max_bs
+    @contextmanager
+    def model_capture_mode(self):
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = True
+
+        yield
+
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = False
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+
+        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+        # because the full_text_row_masked_out_mask tensor will always be ones
+        is_encoder_lens_supported = (
+            torch.all(forward_batch.encoder_lens > 0)
+            if self.is_encoder_decoder
+            else True
+        )
+        return is_bs_supported and is_encoder_lens_supported
 
     def capture(self):
         with graph_capture() as graph_capture_context:
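The new `model_capture_mode` helper is a plain `contextlib.contextmanager` that flips an optional flag on the model for the duration of graph capture. A minimal sketch of the same pattern, using a hypothetical `DummyModel` that is not part of sglang:

```python
from contextlib import contextmanager


class DummyModel:
    capture_mode = False  # models that opt in expose this flag


@contextmanager
def model_capture_mode(model):
    # Flip the flag only if the model defines it, mirroring the hasattr checks.
    if hasattr(model, "capture_mode"):
        model.capture_mode = True
    try:
        yield
    finally:
        # try/finally is an extra hardening choice in this sketch; the diff
        # above resets the flag without it.
        if hasattr(model, "capture_mode"):
            model.capture_mode = False


model = DummyModel()
with model_capture_mode(model):
    assert model.capture_mode  # graph capture would run here
assert not model.capture_mode
```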
@@ -188,10 +228,20 @@ class CudaGraphRunner:
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+
+        seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-            bs, req_pool_indices, seq_lens
+            bs,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
         )
 
         # Run and capture
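Both `capture` and `replay` work on leading slices of tensors that were allocated once at the maximum batch size, because a replayed CUDA graph reads from and writes to fixed device addresses; the new `encoder_lens` and `mrope_positions` buffers follow the same pattern. A rough, CPU-only sketch of that staging idea (names and sizes are illustrative, not sglang's; real buffers live on the GPU):

```python
import torch

MAX_BS = 32
# Allocated once; CUDA graphs require inputs at stable addresses.
seq_lens = torch.ones(MAX_BS, dtype=torch.int32)


def stage_inputs(new_seq_lens: torch.Tensor) -> torch.Tensor:
    bs = new_seq_lens.numel()
    # copy_ writes in place, so a captured graph sees the fresh values.
    seq_lens[:bs].copy_(new_seq_lens)
    return seq_lens[:bs]  # a view into the same storage, not a new tensor


view = stage_inputs(torch.tensor([5, 3], dtype=torch.int32))
assert view.data_ptr() == seq_lens.data_ptr()
```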
@@ -206,11 +256,15 @@ class CudaGraphRunner:
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
+                seq_lens_sum=seq_lens_sum,
+                encoder_lens=encoder_lens,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+                positions=clamp_position(seq_lens),
+                mrope_positions=mrope_positions,
             )
-            return forward(input_ids, forward_batch.positions, forward_batch)
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits
 
         for _ in range(2):
             torch.cuda.synchronize()
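`clamp_position(seq_lens)` is a new helper defined outside this hunk. Judging from the call site, it maps each sequence length to the position of the token being decoded, clamped at zero so padded rows (length 0) stay valid indices. A stand-in with that behavior, assuming only what the call site implies:

```python
import torch


def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # The token being decoded sits at index seq_len - 1; the clamp keeps
    # padded rows (seq_len == 0) from producing position -1.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)


print(clamp_position(torch.tensor([5, 1, 0])))  # tensor([4, 0, 0])
```

Note also that `run_once` now returns only `next_token_logits`, so the captured output buffer holds a single tensor rather than a full logits object.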
@@ -241,7 +295,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(
+            self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
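`replay` pads a raw batch size up to the nearest captured graph size with `bisect_left`, then overwrites only the first `raw_bs` rows of the static buffers. The rounding step in isolation (the capture sizes here are illustrative):

```python
import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # batch sizes with pre-captured graphs


def padded_bs(raw_bs: int) -> int:
    # bisect_left finds the first captured size >= raw_bs,
    # i.e. the batch is always rounded up, never down.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]


assert padded_bs(3) == 4
assert padded_bs(8) == 8
```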
@@ -249,31 +303,32 @@ class CudaGraphRunner:
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        if self.is_encoder_decoder:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
-            bs, self.req_pool_indices, self.seq_lens
+            bs,
+            self.req_pool_indices,
+            self.seq_lens,
+            forward_batch.seq_lens_sum + (bs - raw_bs),
+            self.encoder_lens,
         )
 
         # Replay
         self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
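With the output buffer now holding only raw logits, the replay path rebuilds `LogitsProcessorOutput` on demand and computes logprobs outside the captured graph, only when the request asks for them. The core computation in isolation (batch and vocab sizes are made up):

```python
import torch

next_token_logits = torch.randn(4, 1000)  # [batch, vocab], as sliced above
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

# Top-k extraction of the kind get_top_logprobs performs per request:
top_vals, top_idx = next_token_logprobs.topk(5, dim=-1)
print(top_vals.shape, top_idx.shape)  # torch.Size([4, 5]) twice
```

Keeping `log_softmax` out of the graph means the common no-logprob path replays less work per decode step.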
@@ -282,7 +337,11 @@ class CudaGraphRunner:
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
 
         return logits_output