sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- sglang/bench_latency.py +28 -10
 - sglang/bench_server_latency.py +21 -10
 - sglang/bench_serving.py +101 -7
 - sglang/global_config.py +0 -1
 - sglang/srt/layers/attention/__init__.py +27 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
 - sglang/srt/layers/attention/flashinfer_backend.py +352 -83
 - sglang/srt/layers/attention/triton_backend.py +6 -4
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
 - sglang/srt/layers/sampler.py +6 -2
 - sglang/srt/managers/detokenizer_manager.py +31 -10
 - sglang/srt/managers/io_struct.py +4 -0
 - sglang/srt/managers/schedule_batch.py +120 -43
 - sglang/srt/managers/schedule_policy.py +2 -1
 - sglang/srt/managers/scheduler.py +202 -140
 - sglang/srt/managers/tokenizer_manager.py +5 -1
 - sglang/srt/managers/tp_worker.py +111 -1
 - sglang/srt/mem_cache/chunk_cache.py +8 -4
 - sglang/srt/mem_cache/memory_pool.py +77 -4
 - sglang/srt/mem_cache/radix_cache.py +15 -7
 - sglang/srt/model_executor/cuda_graph_runner.py +4 -4
 - sglang/srt/model_executor/forward_batch_info.py +16 -21
 - sglang/srt/model_executor/model_runner.py +60 -1
 - sglang/srt/models/baichuan.py +2 -3
 - sglang/srt/models/chatglm.py +5 -6
 - sglang/srt/models/commandr.py +1 -2
 - sglang/srt/models/dbrx.py +1 -2
 - sglang/srt/models/deepseek.py +4 -5
 - sglang/srt/models/deepseek_v2.py +5 -6
 - sglang/srt/models/exaone.py +1 -2
 - sglang/srt/models/gemma.py +2 -2
 - sglang/srt/models/gemma2.py +5 -5
 - sglang/srt/models/gpt_bigcode.py +5 -5
 - sglang/srt/models/grok.py +1 -2
 - sglang/srt/models/internlm2.py +1 -2
 - sglang/srt/models/llama.py +1 -2
 - sglang/srt/models/llama_classification.py +1 -2
 - sglang/srt/models/llama_reward.py +2 -3
 - sglang/srt/models/llava.py +4 -8
 - sglang/srt/models/llavavid.py +1 -2
 - sglang/srt/models/minicpm.py +1 -2
 - sglang/srt/models/minicpm3.py +5 -6
 - sglang/srt/models/mixtral.py +1 -2
 - sglang/srt/models/mixtral_quant.py +1 -2
 - sglang/srt/models/olmo.py +352 -0
 - sglang/srt/models/olmoe.py +1 -2
 - sglang/srt/models/qwen.py +1 -2
 - sglang/srt/models/qwen2.py +1 -2
 - sglang/srt/models/qwen2_moe.py +4 -5
 - sglang/srt/models/stablelm.py +1 -2
 - sglang/srt/models/torch_native_llama.py +1 -2
 - sglang/srt/models/xverse.py +1 -2
 - sglang/srt/models/xverse_moe.py +4 -5
 - sglang/srt/models/yivl.py +1 -2
 - sglang/srt/openai_api/adapter.py +92 -49
 - sglang/srt/openai_api/protocol.py +10 -2
 - sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
 - sglang/srt/sampling/sampling_batch_info.py +92 -58
 - sglang/srt/sampling/sampling_params.py +2 -0
 - sglang/srt/server.py +116 -17
 - sglang/srt/server_args.py +121 -45
 - sglang/srt/utils.py +11 -3
 - sglang/test/few_shot_gsm8k.py +4 -1
 - sglang/test/few_shot_gsm8k_engine.py +144 -0
 - sglang/test/srt/sampling/penaltylib/utils.py +16 -12
 - sglang/version.py +1 -1
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
 - sglang/srt/layers/attention/flashinfer_utils.py +0 -237
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
 
sglang/srt/layers/attention/triton_ops/extend_attention.py
CHANGED

@@ -26,7 +26,9 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
 
-
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -286,12 +288,12 @@ def extend_attention_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    if CUDA_CAPABILITY[0] >= 9:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
         if Lq <= 256:
             BLOCK_M, BLOCK_N = (128, 64)
         else:
             BLOCK_M, BLOCK_N = (32, 64)
-    elif CUDA_CAPABILITY[0] >= 8:
+    elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
         if Lq <= 128:
             BLOCK_M, BLOCK_N = (128, 128)
         elif Lq <= 256:

sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED

@@ -24,7 +24,9 @@ import torch
 import triton
 import triton.language as tl
 
-
+is_cuda_available = torch.cuda.is_available()
+if is_cuda_available:
+    CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
 @triton.jit
@@ -145,7 +147,7 @@ def _fwd_kernel(
 
 
 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    if CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
        BLOCK = 128
     else:
        BLOCK = 64
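Note: both Triton modules previously read torch.cuda.get_device_capability() unconditionally at import time, which raises on machines without a GPU; 0.3.4 guards the probe and re-checks is_cuda_available before branching on CUDA_CAPABILITY. A minimal standalone sketch of the same pattern, outside sglang (pick_block_sizes is a hypothetical helper, and the final fallback tile size is an assumption, not the kernel's actual else branch):

    import torch

    # Probe once at import time; guarded so the module still imports on CPU-only hosts.
    is_cuda_available = torch.cuda.is_available()
    if is_cuda_available:
        CUDA_CAPABILITY = torch.cuda.get_device_capability()

    def pick_block_sizes(head_dim: int):
        """Choose (BLOCK_M, BLOCK_N) Triton tile sizes by GPU generation."""
        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:  # Hopper, sm_90+
            return (128, 64) if head_dim <= 256 else (32, 64)
        if is_cuda_available and CUDA_CAPABILITY[0] >= 8:  # Ampere, sm_80+
            return (128, 128) if head_dim <= 128 else (64, 64)
        return (64, 64)  # assumed conservative fallback for older GPUs or no GPU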
    
        sglang/srt/layers/sampler.py
    CHANGED
    
@@ -21,6 +21,10 @@ logger = logging.getLogger(__name__)
 
 
 class Sampler(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+
     def forward(
         self,
         logits: Union[torch.Tensor, LogitsProcessorOutput],
@@ -36,13 +40,13 @@ class Sampler(nn.Module):
             logits = None
             del logits
 
-        if torch.any(torch.isnan(probs)):
+        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
             logger.warning("Detected errors during sampling! NaN in the probability.")
             probs = torch.where(
                 torch.isnan(probs), torch.full_like(probs, 1e-10), probs
             )
 
-        if sampling_info.top_ks.max().item() <= 1:
+        if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(probs, -1)
         elif global_server_args_dict["sampling_backend"] == "flashinfer":
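Note: the NaN check now runs only when NaN detection is enabled (it can be switched off with the new disable_nan_detection server argument), since scanning the full probability tensor on every step has a cost. A minimal sketch of the guarded path, assuming a 2-D (batch, vocab) probability tensor; this is an illustration, not the actual Sampler:

    import torch

    def sample_next_tokens(probs: torch.Tensor, use_nan_detection: bool = True) -> torch.Tensor:
        # Optionally scrub NaNs produced by unstable logits processing; replacing
        # them with a tiny probability keeps torch.multinomial valid.
        if use_nan_detection and torch.any(torch.isnan(probs)):
            probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)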
sglang/srt/managers/detokenizer_manager.py
CHANGED

@@ -18,7 +18,7 @@ limitations under the License.
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union
 
 import zmq
 
@@ -29,7 +29,7 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
+        if no_stop_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""
 
@@ -122,7 +137,13 @@ class DetokenizerManager:
                     s = self.decode_status[rid]
                     s.decode_ids = recv_obj.decode_ids[i]
 
-                read_ids.append(s.decode_ids[s.surr_offset :])
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])
 
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ class DetokenizerManager:
                     else:
                         new_text = find_printable_text(new_text)
 
-                output_strs.append(s.decoded_text + new_text)
-
-                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_stop_trim[i],
+                    )
+                )
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
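Note: the new trim_eos helper cuts the matched stop string (or the final stop token) off the output unless the request set no_stop_trim. A rough illustration of the intended behavior; MatchedStr is a simplified stand-in for FINISH_MATCHED_STR:

    from dataclasses import dataclass

    @dataclass
    class MatchedStr:  # simplified stand-in for FINISH_MATCHED_STR
        matched: str

    def trim_eos(output, finished_reason, no_stop_trim=False):
        if no_stop_trim:
            return output
        if isinstance(finished_reason, MatchedStr) and isinstance(output, str):
            pos = output.find(finished_reason.matched)
            return output[:pos] if pos != -1 else output
        return output

    print(trim_eos("Hello world.###", MatchedStr("###")))                     # Hello world.
    print(trim_eos("Hello world.###", MatchedStr("###"), no_stop_trim=True))  # Hello world.###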
    
        sglang/srt/managers/io_struct.py
    CHANGED
    
@@ -56,6 +56,9 @@ class GenerateReqInput:
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
+    # Whether it is a single request or a batch request
+    is_single: bool = True
+
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -295,6 +298,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_stop_trim: List[bool]
 
 
 @dataclass
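Note: no_stop_trim travels as one flag per request through BatchTokenIDOut's parallel lists, so the detokenizer can decide trimming per entry. A minimal sketch of the same parallel-list layout (BatchOut and its fields are illustrative, not the real struct):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class BatchOut:  # illustrative; mirrors the parallel-list layout of BatchTokenIDOut
        rids: List[str]
        output_ids: List[List[int]]
        no_stop_trim: List[bool]  # one entry per request, consumed by the detokenizer

    batch = BatchOut(rids=["a", "b"], output_ids=[[1, 2], [3]], no_stop_trim=[False, True])
    assert len(batch.no_stop_trim) == len(batch.rids)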
sglang/srt/managers/schedule_batch.py
CHANGED

@@ -53,6 +53,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
+    "disable_nan_detection": ServerArgs.disable_nan_detection,
 }
 
 
@@ -196,6 +197,9 @@ class Req:
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
 
+        # The number of cached tokens, that were already cached in the KV store
+        self.cached_tokens = 0
+
         # For vision inputs
         self.image_inputs: Optional[ImageInputs] = None
 
@@ -203,6 +207,7 @@ class Req:
         self.prefix_indices = []
         self.extend_input_len = 0
         self.last_node = None
+        self.is_inflight_req = 0
 
         # Logprobs (arguments)
         self.return_logprob = False
@@ -391,25 +396,30 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
+bid = 0
+
+
 @dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
-    req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: BaseTokenToKVPool
-    tree_cache: BasePrefixCache
+    req_to_token_pool: ReqToTokenPool = None
+    token_to_kv_pool: BaseTokenToKVPool = None
+    tree_cache: BasePrefixCache = None
 
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids:
-    req_pool_indices:
-    seq_lens:
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -419,6 +429,7 @@ class ScheduleBatch:
     extend_lens: List[int] = None
     extend_num_tokens: int = None
     running_bs: int = None
+    decoding_reqs: List[Req] = None
 
     # Stream
     has_stream: bool = False
@@ -492,17 +503,24 @@ class ScheduleBatch:
 
         pt = 0
         for i, req in enumerate(reqs):
+            already_computed = (
+                req.extend_logprob_start_len + 1 + req.cached_tokens
+                if req.extend_logprob_start_len > 0
+                else 0
+            )
+            req.cached_tokens += len(req.prefix_indices) - already_computed
+
             req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
 
             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    :pre_len
-                ] = req.prefix_indices
+                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
+                    req.prefix_indices
+                )
 
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
@@ -518,10 +536,15 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-        with torch.device("cuda"):
-            self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
@@ -531,7 +554,9 @@ class ScheduleBatch:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
 
-        self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self, vocab_size, global_server_args_dict["disable_penalizer"]
+        )
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
         self.forward_mode = ForwardMode.MIXED
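Note: the new cached_tokens counter records how many prompt tokens were served from the prefix cache. On each extend, the freshly matched prefix length is added, minus what an earlier chunk already counted (the already_computed expression above). A hedged sketch of that bookkeeping in isolation:

    def update_cached_tokens(cached_tokens: int, prefix_len: int, extend_logprob_start_len: int) -> int:
        # Tokens counted by a previous chunk (plus the logprob start offset)
        # must not be counted again.
        already_computed = (
            extend_logprob_start_len + 1 + cached_tokens
            if extend_logprob_start_len > 0
            else 0
        )
        return cached_tokens + prefix_len - already_computed

    # First chunk: 40 prompt tokens hit the radix cache.
    assert update_cached_tokens(0, 40, 0) == 40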
@@ -586,9 +611,11 @@ class ScheduleBatch:
 
         retracted_reqs = []
         seq_lens_cpu = self.seq_lens.cpu().numpy()
+        first_iter = True
         while (
             self.token_to_kv_pool.available_size()
             < len(sorted_indices) * global_config.retract_decode_steps
+            or first_iter
         ):
             if len(sorted_indices) == 1:
                 # Corner case: only one request left
@@ -597,6 +624,7 @@ class ScheduleBatch:
                 ), "No space left for only one request"
                 break
 
+            first_iter = False
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
@@ -637,7 +665,7 @@ class ScheduleBatch:
             req.last_update_decode_tokens = 0
             req.logprob_start_len = 10**9
 
-        self.filter_batch(sorted_indices)
+        self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
         total_decoded_tokens = sum(len(r.output_ids) for r in self.reqs)
@@ -652,7 +680,7 @@ class ScheduleBatch:
 
     def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
-        filter_indices = [i for i in range(len(self.reqs))]
+        keep_indices = set(i for i in range(len(self.reqs)))
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
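Note: first_iter forces the retraction loop to evict at least one decoding request even when the KV pool momentarily reports enough free space, which guarantees forward progress. A toy version of the control flow only (the real code frees radix-cache nodes rather than adding back a fixed quantum):

    def retract(sorted_indices, available_size, quantum):
        retracted = []
        first_iter = True
        while available_size < len(sorted_indices) * quantum or first_iter:
            if len(sorted_indices) == 1:
                break  # corner case: never retract the last running request
            first_iter = False
            retracted.append(sorted_indices.pop())
            available_size += quantum  # evicting a request returns (roughly) its KV slots
        return retracted

    assert retract([0, 1, 2, 3], available_size=100, quantum=5) != []  # always evicts one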
@@ -712,63 +740,71 @@ class ScheduleBatch:
                     )
 
                 jump_forward_reqs.append(req)
-                filter_indices.remove(i)
+                keep_indices.remove(i)
 
-        self.filter_batch(filter_indices)
+        self.filter_batch(keep_indices=list(keep_indices))
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
-        self.seq_lens.add_(1)
+        self.input_ids = self.output_ids
+        self.output_ids = None
+        if self.sampling_info.penalizer_orchestrator:
+            self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
+                self.input_ids
+            )
 
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
 
-        self.req_to_token_pool.req_to_token[
-            self.req_pool_indices, self.seq_lens - 1
-        ] = self.out_cache_loc
+        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
+            self.out_cache_loc
+        )
+        self.seq_lens.add_(1)
 
-    def filter_batch(self, unfinished_indices: List[int]):
-        if unfinished_indices is None or len(unfinished_indices) == 0:
+    def filter_batch(
+        self,
+        current_inflight_req: Optional[Req] = None,
+        keep_indices: Optional[List[int]] = None,
+    ):
+        if keep_indices is None:
+            keep_indices = [
+                i
+                for i in range(len(self.reqs))
+                if not self.reqs[i].finished()
+                and self.reqs[i] is not current_inflight_req
+            ]
+
+        if keep_indices is None or len(keep_indices) == 0:
             # Filter out all requests
             self.reqs = []
             return
 
-        if len(unfinished_indices) == len(self.reqs):
+        if len(keep_indices) == len(self.reqs):
             # No need to filter
             return
 
-        self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(
-            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        self.reqs = [self.reqs[i] for i in keep_indices]
+        new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
+            self.device, non_blocking=True
         )
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
-            self.top_logprobs_nums = [
-                self.top_logprobs_nums[i] for i in unfinished_indices
-            ]
+            self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in keep_indices]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
         self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter_batch(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(keep_indices, new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -781,6 +817,8 @@ class ScheduleBatch:
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -813,7 +851,11 @@ class ScheduleBatch:
         else:
             self.sampling_info.regex_fsms = None
 
+        global bid
+        bid += 1
+
         return ModelWorkerBatch(
+            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -829,9 +871,26 @@ class ScheduleBatch:
             sampling_info=self.sampling_info,
         )
 
+    def copy(self):
+        return ScheduleBatch(
+            reqs=self.reqs,
+            forward_mode=self.forward_mode,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            decoding_reqs=self.decoding_reqs,
+        )
+
+    def __str__(self):
+        return (
+            f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+            f"#req={(len(self.reqs))})"
+        )
+
 
 @dataclass
 class ModelWorkerBatch:
+    # The batch id
+    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -860,3 +919,21 @@ class ModelWorkerBatch:
 
     # Sampling info
     sampling_info: SamplingBatchInfo
+
+    def copy(self):
+        return ModelWorkerBatch(
+            bid=self.bid,
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids.clone(),
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens.clone(),
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=self.extend_seq_lens,
+            extend_prefix_lens=self.extend_prefix_lens,
+            extend_logprob_start_lens=self.extend_logprob_start_lens,
+            image_inputs=self.image_inputs,
+            lora_paths=self.lora_paths,
+            sampling_info=self.sampling_info.copy(),
+        )
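Note: filter_batch now takes explicit keep_indices (optionally excluding an in-flight chunked-prefill request), and every batch-aligned tensor is gathered with one index tensor; ModelWorkerBatch.copy() clones only the tensors the scheduler mutates between steps (input_ids, seq_lens) while sharing the rest. A toy version of the index-filter step (filter_rows is illustrative, not sglang API):

    import torch

    def filter_rows(reqs, seq_lens: torch.Tensor, keep_indices):
        # Keep the listed requests and gather the matching rows of each
        # batch-aligned tensor with a single index tensor.
        new_indices = torch.tensor(keep_indices, dtype=torch.int64)
        return [reqs[i] for i in keep_indices], seq_lens[new_indices]

    reqs, seq_lens = filter_rows(["r0", "r1", "r2"], torch.tensor([5, 9, 7]), [0, 2])
    assert reqs == ["r0", "r2"] and seq_lens.tolist() == [5, 7]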
sglang/srt/managers/schedule_policy.py
CHANGED

@@ -45,12 +45,13 @@ class SchedulePolicy:
     def calc_priority(self, waiting_queue: List[Req]):
         # Compute matched prefix length
         prefix_computed = False
-        if self.policy in ["lpm", "dfs-weight"]:
+        if self.policy == "lpm" or self.policy == "dfs-weight":
             for r in waiting_queue:
                 # NOTE: the prefix_indices must always be aligned with last_node
                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
                     rid=r.rid, key=r.adjust_max_prefix_ids()
                 )
+
             prefix_computed = True
 
         if self.policy == "lpm":
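Note: for the lpm and dfs-weight policies, calc_priority first refreshes each waiting request's matched prefix against the radix cache and then ranks by match length. A toy longest-prefix-match priority; match_len stands in for tree_cache.match_prefix:

    def match_len(cached, key):
        n = 0
        while n < min(len(cached), len(key)) and cached[n] == key[n]:
            n += 1
        return n

    # "lpm": serve requests with the longest cached prefix first.
    cache = [1, 2, 3, 4]
    queue = [[1, 2, 9], [1, 2, 3, 4, 5], [7]]
    queue.sort(key=lambda ids: -match_len(cache, ids))
    assert queue[0] == [1, 2, 3, 4, 5]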