sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
sglang/srt/managers/schedule_batch.py

@@ -23,17 +23,20 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """

+import dataclasses
 import logging
-from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch

 from sglang.global_config import global_config
+from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
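A quick illustration of the flow the docstring above describes, as a minimal sketch with hypothetical, stripped-down classes (the real sglang types carry far more state):

```python
import dataclasses
from typing import List

# Hypothetical stand-ins for the three batch types named in the docstring.
@dataclasses.dataclass
class ToyScheduleBatch:            # CPU-side: high-level scheduling data
    req_ids: List[int]
    input_ids: List[List[int]]

    def get_model_worker_batch(self) -> "ToyModelWorkerBatch":
        # Keep only what the GPU forward pass needs.
        return ToyModelWorkerBatch(input_ids=self.input_ids)

@dataclasses.dataclass
class ToyModelWorkerBatch:         # the subset shipped to the model worker
    input_ids: List[List[int]]

@dataclasses.dataclass
class ToyForwardBatch:             # GPU-side: low-level tensor data in sglang
    flat_input_ids: List[int]      # a torch.Tensor on device in the real code

    @classmethod
    def init_new(cls, mwb: ToyModelWorkerBatch) -> "ToyForwardBatch":
        return cls(flat_input_ids=[t for seq in mwb.input_ids for t in seq])

fb = ToyForwardBatch.init_new(
    ToyScheduleBatch(req_ids=[0, 1], input_ids=[[1, 2], [3]]).get_model_worker_batch()
)
assert fb.flat_input_ids == [1, 2, 3]
```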
@@ -114,38 +117,50 @@ class FINISH_ABORT(BaseFinishReason):
         }


-@dataclass
+@dataclasses.dataclass
 class ImageInputs:
     """The image related inputs."""

     pixel_values: torch.Tensor
-    image_hash: int
+    image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
+    num_image_tokens: Optional[int] = None

     image_embeds: Optional[List[torch.Tensor]] = None
     aspect_ratio_ids: Optional[List[torch.Tensor]] = None
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+    # QWen2-VL related
+    image_grid_thws: List[Tuple[int, int, int]] = None

     @staticmethod
     def from_dict(obj, vocab_size):
         # Use image hash as fake token_ids, which is then used for prefix matching
         ret = ImageInputs(
             pixel_values=obj["pixel_values"],
-            image_hash=hash(tuple(obj["image_hashes"])),
+            image_hashes=hash(tuple(obj["image_hashes"])),
         )
-        image_hash = ret.image_hash
+        image_hash = ret.image_hashes
         ret.pad_values = [
             (image_hash) % vocab_size,
             (image_hash >> 16) % vocab_size,
             (image_hash >> 32) % vocab_size,
             (image_hash >> 64) % vocab_size,
         ]
-        ret.image_sizes = obj["image_sizes"]
-        # Only when pixel values is not None we have modalities
-        ret.modalities = obj["modalities"] or ["image"]
+
+        optional_args = [
+            "image_sizes",
+            "modalities",
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if arg in obj:
+                setattr(ret, arg, obj[arg])
+
         return ret


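The from_dict change above consolidates the per-image hashes into one value and keeps deriving four pseudo token ids from it. A self-contained rehearsal of that arithmetic (vocab size and hash inputs are made up):

```python
vocab_size = 152_064                      # made-up vocab size for illustration
image_hashes = [0x1234ABCD, 0x9876FEDC]   # made-up per-image hashes
image_hash = hash(tuple(image_hashes))    # single hash for the whole image list

# Four fake token ids cut from different bit ranges of the hash. Padding the
# image span of the prompt with these ids means two requests with the same
# images share a token prefix, so the radix cache can match and reuse KV.
pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
assert all(0 <= v < vocab_size for v in pad_values)
```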
@@ -236,6 +251,9 @@ class Req:
         self.regex_fsm_state: int = 0
         self.jump_forward_map: JumpForwardMap = None

+        # For Qwen2-VL
+        self.mrope_position_delta = []  # use mutable object
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
@@ -316,15 +334,20 @@ class Req:

         last_token_id = self.output_ids[-1]

-        matched_eos = last_token_id in self.sampling_params.stop_token_ids
+        matched_eos = False

+        # Check stop token ids
+        if self.sampling_params.stop_token_ids:
+            matched_eos = last_token_id in self.sampling_params.stop_token_ids
         if self.tokenizer is not None:
             matched_eos |= last_token_id == self.tokenizer.eos_token_id
-
+            if self.tokenizer.additional_stop_token_ids:
+                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
         if matched_eos and not self.sampling_params.ignore_eos:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
                 self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
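The rewritten check above makes the precedence explicit: per-request stop_token_ids first, then the tokenizer's EOS and any additional stop ids, and only afterwards the stop strings. A toy version of the token-level part (all ids hypothetical):

```python
def matched_stop_token(last_token_id, stop_token_ids, eos_token_id,
                       additional_stop_token_ids):
    # Per-request stop ids from the sampling params, if any were given.
    matched = bool(stop_token_ids) and last_token_id in stop_token_ids
    # Tokenizer-level EOS and extra stop ids.
    matched |= last_token_id == eos_token_id
    if additional_stop_token_ids:
        matched |= last_token_id in additional_stop_token_ids
    return matched

assert matched_stop_token(7, {7, 9}, 2, None)       # request-level stop id
assert matched_stop_token(2, set(), 2, None)        # plain EOS
assert not matched_stop_token(5, {7, 9}, 2, {11})   # keep generating
```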
@@ -399,7 +422,7 @@ class Req:
 bid = 0


-@dataclass
+@dataclasses.dataclass
 class ScheduleBatch:
     """Store all inforamtion of a batch."""

@@ -409,6 +432,9 @@ class ScheduleBatch:
     token_to_kv_pool: BaseTokenToKVPool = None
     tree_cache: BasePrefixCache = None

+    # For utility
+    model_config: ModelConfig = None
+
     forward_mode: ForwardMode = None
     sampling_info: SamplingBatchInfo = None

@@ -416,10 +442,13 @@ class ScheduleBatch:
     input_ids: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
+    # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None
-
     output_ids: torch.Tensor = None

+    # The sum of all sequence lengths
+    seq_lens_sum: int = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -428,33 +457,42 @@ class ScheduleBatch:
     prefix_lens: List[int] = None
     extend_lens: List[int] = None
     extend_num_tokens: int = None
-    running_bs: int = None
     decoding_reqs: List[Req] = None

+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]] = None
+    encoder_lens: Optional[torch.Tensor] = None
+    encoder_lens_cpu: Optional[List[int]] = None
+    encoder_out_cache_loc: Optional[torch.Tensor] = None
+
     # Stream
     has_stream: bool = False

-    # device
-    device: str = "cuda"
-
     # Has regex
     has_regex: bool = False

-    @classmethod
-    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_logprob = any(req.return_logprob for req in reqs)
-        has_stream = any(req.stream for req in reqs)
-        has_regex = any(req.regex_fsm for req in reqs)
+    # device
+    device: str = "cuda"

+    @classmethod
+    def init_new(
+        cls,
+        reqs,
+        req_to_token_pool,
+        token_to_kv_pool,
+        tree_cache,
+        model_config,
+    ):
         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_logprob=return_logprob,
-            has_stream=has_stream,
+            model_config=model_config,
+            return_logprob=any(req.return_logprob for req in reqs),
+            has_stream=any(req.stream for req in reqs),
+            has_regex=any(req.regex_fsm for req in reqs),
             device=req_to_token_pool.device,
-            has_regex=has_regex,
         )

     def batch_size(self):
@@ -481,14 +519,90 @@ class ScheduleBatch:
                 out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)

             if out_cache_loc is None:
-                logger.error("Prefill out of memory. Try to lower your batch size.")
+                phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
+                logger.error(
+                    f"{phase_str} out of memory. Try to lower your batch size.\n"
+                    f"Try to allocate {num_tokens} tokens.\n"
+                    f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
+                )
                 if self.tree_cache is not None:
                     self.tree_cache.pretty_print()
                 exit(1)

         return out_cache_loc

-    def prepare_for_extend(self, vocab_size: int):
+    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
+        self.encoder_lens_cpu = []
+        self.encoder_cached = []
+
+        for req in self.reqs:
+            im = req.image_inputs
+            if im is None or im.num_image_tokens is None:
+                # No image input
+                self.encoder_lens_cpu.append(0)
+                self.encoder_cached.append(True)
+            else:
+                self.encoder_lens_cpu.append(im.num_image_tokens)
+                self.encoder_cached.append(
+                    self.forward_mode.is_decode()
+                    or len(req.prefix_indices) >= im.num_image_tokens
+                )
+
+        self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        # Strip encoder infos
+        pt = 0
+        decoder_out_cache_loc = []
+        encoder_out_cache_loc = []
+        for i, req in enumerate(self.reqs):
+            encoder_len = self.encoder_lens_cpu[i]
+            seq_lens[i] -= encoder_len
+
+            if len(req.prefix_indices) < encoder_len:
+                # NOTE: the encoder part should considered as a whole
+                assert len(req.prefix_indices) == 0
+                input_ids[i] = input_ids[i][encoder_len:]
+                encoder_out_cache_loc.append(self.out_cache_loc[pt : pt + encoder_len])
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt + encoder_len : pt + req.extend_input_len]
+                )
+                self.extend_lens[i] -= encoder_len
+                self.extend_num_tokens -= encoder_len
+            else:
+                decoder_out_cache_loc.append(
+                    self.out_cache_loc[pt : pt + req.extend_input_len]
+                )
+                self.prefix_lens[i] -= encoder_len
+
+            pt += req.extend_input_len
+
+        # Reassign
+        self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+
+        if not decoder_out_cache_loc:
+            self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.out_cache_loc = torch.cat(decoder_out_cache_loc)
+
+        if not encoder_out_cache_loc:
+            self.encoder_out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+                self.device, non_blocking=True
+            )
+        else:
+            self.encoder_out_cache_loc = torch.cat(encoder_out_cache_loc)
+
+        assert len(self.out_cache_loc) == self.extend_num_tokens
+
+    def prepare_for_extend(self):
         self.forward_mode = ForwardMode.EXTEND

         bs = len(self.reqs)
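prepare_encoder_info_extend above splits each request's newly allocated KV slots into an encoder part (the leading image tokens) and a decoder part. A toy, list-based rehearsal of that split with made-up lengths:

```python
extend_input_lens = [5, 3]       # tokens written per request in this extend step
encoder_lens = [2, 0]            # leading encoder (image) tokens per request
out_cache_loc = list(range(8))   # one KV slot per extend token, packed per request

pt = 0
encoder_loc, decoder_loc = [], []
for ext_len, enc_len in zip(extend_input_lens, encoder_lens):
    # The first enc_len slots of each request's span feed the encoder cache.
    encoder_loc += out_cache_loc[pt : pt + enc_len]
    decoder_loc += out_cache_loc[pt + enc_len : pt + ext_len]
    pt += ext_len

assert encoder_loc == [0, 1]
assert decoder_loc == [2, 3, 4, 5, 6, 7]
```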
@@ -516,12 +630,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx][:pre_len] = (
-                    req.prefix_indices
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-
-            self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
             )

             # Compute the relative logprob_start_len in an extend batch
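The slice-based write() used above replaces direct req_to_token indexing; together with get_write_records() in the last hunk of this file, it suggests the pool now logs its writes, presumably so the new overlap worker (tp_worker_overlap_thread.py, also added in this release) can replay them on its own copy. A minimal sketch under that assumption, not the actual sglang implementation:

```python
import torch

class ReqToTokenPoolSketch:
    def __init__(self, size: int, max_context_len: int):
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.write_records = []

    def write(self, indices, values):
        # indices may be (row, slice) for one request's extend, or
        # (row_tensor, col_tensor) for a whole decode batch.
        self.req_to_token[indices] = values
        self.write_records.append((indices, values))

    def get_write_records(self):
        records, self.write_records = self.write_records, []
        return records

pool = ReqToTokenPoolSketch(size=4, max_context_len=16)
pool.write((0, slice(0, 3)), torch.tensor([10, 11, 12], dtype=torch.int32))
pool.write(
    (torch.tensor([0, 1]), torch.tensor([3, 0])),  # decode: one new slot per request
    torch.tensor([13, 14], dtype=torch.int32),
)
assert len(pool.get_write_records()) == 2
assert pool.get_write_records() == []  # records are drained once fetched
```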
@@ -546,16 +660,23 @@ class ScheduleBatch:
             self.device, non_blocking=True
         )

-        self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+
+        self.seq_lens_sum = sum(seq_lens)
         if self.return_logprob:
             self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]

+        if self.model_config.is_encoder_decoder:
+            self.prepare_encoder_info_extend(input_ids, seq_lens)
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
-            self, vocab_size
+            self,
+            self.model_config.vocab_size,
+            global_server_args_dict["disable_penalizer"],
         )

     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -568,12 +689,11 @@ class ScheduleBatch:

         input_ids = torch.cat([self.input_ids, running_batch.input_ids])
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
-        extend_num_tokens = self.extend_num_tokens + running_bs

         self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
-        self.extend_num_tokens = extend_num_tokens
+        self.extend_num_tokens += running_bs

         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
         self.prefix_lens.extend(
@@ -631,8 +751,8 @@ class ScheduleBatch:

             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    : seq_lens_cpu[idx]
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, : seq_lens_cpu[idx]
                 ]
                 self.token_to_kv_pool.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
@@ -640,8 +760,8 @@ class ScheduleBatch:
             else:
                 # TODO: apply more fine-grained retraction
                 last_uncached_pos = len(req.prefix_indices)
-                token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
-                    last_uncached_pos : seq_lens_cpu[idx]
+                token_indices = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
                 ]
                 self.token_to_kv_pool.free(token_indices)
                 self.req_to_token_pool.free(req.req_pool_idx)
@@ -746,7 +866,11 @@ class ScheduleBatch:

         return jump_forward_reqs

-    def prepare_for_decode(self):
+    def prepare_encoder_info_decode(self):
+        # Reset the encoder cached status
+        self.encoder_cached = [True] * len(self.reqs)
+
+    def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE

         self.input_ids = self.output_ids
@@ -760,10 +884,25 @@ class ScheduleBatch:
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)

-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
-        self.seq_lens.add_(1)
+        if self.model_config.is_encoder_decoder:
+            locs = self.encoder_lens + self.seq_lens
+            self.prepare_encoder_info_decode()
+        else:
+            locs = self.seq_lens
+
+        if enable_overlap:
+            # Do not use in-place operations in the overlap mode
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens = self.seq_lens + 1
+        else:
+            # A faster in-place version
+            self.req_to_token_pool.write(
+                (self.req_pool_indices, locs), self.out_cache_loc
+            )
+            self.seq_lens.add_(1)
+        self.seq_lens_sum += bs

     def filter_batch(
         self,
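The enable_overlap branch above trades speed for safety: when a separate thread may still hold the previous seq_lens tensor, the batch rebinds to a fresh tensor instead of mutating in place. A small demonstration of the difference:

```python
import torch

seq_lens = torch.tensor([7, 12])
snapshot = seq_lens              # reference an overlapped worker might still hold

# Out-of-place (overlap mode): rebinding leaves the old tensor untouched.
seq_lens = seq_lens + 1
assert snapshot.tolist() == [7, 12] and seq_lens.tolist() == [8, 13]

# In-place (normal mode): faster, but mutates through every alias.
alias = seq_lens
seq_lens.add_(1)
assert alias.tolist() == [9, 14]
```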
@@ -787,6 +926,10 @@ class ScheduleBatch:
             # No need to filter
             return

+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = self.encoder_lens[keep_indices]
+            self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices]
+
         self.reqs = [self.reqs[i] for i in keep_indices]
         new_indices = torch.tensor(keep_indices, dtype=torch.int32).to(
             self.device, non_blocking=True
@@ -794,6 +937,7 @@ class ScheduleBatch:
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
@@ -812,11 +956,17 @@ class ScheduleBatch:
         # needs to be called with pre-merged Batch.reqs.
         self.sampling_info.merge_batch(other.sampling_info)

+        # Encoder-decoder infos
+        if self.model_config.is_encoder_decoder:
+            self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
+            self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
+
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
             self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
@@ -833,16 +983,12 @@ class ScheduleBatch:

     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
-            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
-                image_inputs
-            ) = None
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
-            image_inputs = [r.image_inputs for r in self.reqs]

-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -854,6 +1000,8 @@ class ScheduleBatch:
         global bid
         bid += 1
 
+        mrope_positions_delta = [req.mrope_position_delta for req in self.reqs]
+
         return ModelWorkerBatch(
             bid=bid,
             forward_mode=self.forward_mode,
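`mrope_position_delta` is gathered per request for Qwen2-VL's multimodal RoPE, where image tokens occupy 2-D spatial positions and leave the text position counter offset from the raw token index. Carrying the delta lets a decode step recover the correct rotary position from the sequence length alone. A hedged illustration of that idea (not the Qwen2-VL implementation):

```python
def next_decode_position(seq_len: int, mrope_position_delta: int) -> int:
    # Position actually fed to the rotary embedding for the new token.
    return seq_len + mrope_position_delta

# E.g. a request whose image tokens compressed 100 token slots into
# 40 positions would carry a delta of -60:
print(next_decode_position(120, -60))  # -> 60
```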
@@ -861,19 +1009,29 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            seq_lens_sum=self.seq_lens_sum,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
-            image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            image_inputs=[r.image_inputs for r in self.reqs],
+            encoder_cached=self.encoder_cached,
+            encoder_lens=self.encoder_lens,
+            encoder_lens_cpu=self.encoder_lens_cpu,
+            encoder_out_cache_loc=self.encoder_out_cache_loc,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            mrope_positions_delta=mrope_positions_delta,
         )
 
     def copy(self):
+        # Only contain fields that will be used by process_batch_result
         return ScheduleBatch(
             reqs=self.reqs,
+            model_config=self.model_config,
             forward_mode=self.forward_mode,
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
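The slimmed-down `ScheduleBatch.copy()` deliberately snapshots only the fields the result-processing path reads, so the scheduler can keep mutating the live batch while an earlier snapshot is still being consumed (the overlap-worker pattern added in this release). A sketch of that snapshot idea with hypothetical `Step`/`snapshot` names:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Step:
    reqs: List[str]  # shared reference, intentionally not deep-copied
    forward_mode: str
    return_logprob: bool

    def snapshot(self) -> "Step":
        # Keep only what result processing will read later.
        return Step(reqs=self.reqs, forward_mode=self.forward_mode,
                    return_logprob=self.return_logprob)

live = Step(["r0", "r1"], "decode", False)
snap = live.snapshot()
live.forward_mode = "extend"  # mutating the live object...
print(snap.forward_mode)      # ...leaves the snapshot intact: "decode"
```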
@@ -887,7 +1045,7 @@ class ScheduleBatch:
         )
 
 
-@dataclass
+@dataclasses.dataclass
 class ModelWorkerBatch:
     # The batch id
     bid: int
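The decorator is now referenced through the module, which matches the new `dataclasses.replace(...)` call in `copy()` further down: a single `import dataclasses` serves both uses. The two spellings are equivalent:

```python
import dataclasses

@dataclasses.dataclass  # same behavior as `from dataclasses import dataclass`
class Point:
    x: int
    y: int

# replace() builds a shallow copy with selected fields overridden.
p = dataclasses.replace(Point(1, 2), y=5)
print(p)  # Point(x=1, y=5)
```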
@@ -902,11 +1060,18 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor
 
+    # The sum of all sequence lengths
+    seq_lens_sum: int
+
+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
 
     # For extend
+    extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
     extend_prefix_lens: Optional[List[int]]
     extend_logprob_start_lens: Optional[List[int]]
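`req_to_token_pool_records` carries a log of pool writes as `(index, tensor)` pairs so that another thread (the new overlap TP worker) can replay the same writes into its own view of the pool; `get_write_records()` in the earlier hunk presumably drains that log. A hedged record-and-replay sketch (the `WritePool` class and its method names are illustrative, not sglang's API):

```python
from typing import List, Tuple
import torch

class WritePool:
    def __init__(self, rows: int, cols: int):
        self.data = torch.zeros(rows, cols, dtype=torch.int64)
        self._records: List[Tuple[Tuple, torch.Tensor]] = []

    def write(self, index: Tuple, values: torch.Tensor):
        self.data[index] = values
        self._records.append((index, values))  # remember the write

    def get_write_records(self):
        records, self._records = self._records, []  # drain the log
        return records

    def apply_write_records(self, records):
        for index, values in records:
            self.data[index] = values  # replay in another pool/thread

pool_a, pool_b = WritePool(4, 8), WritePool(4, 8)
pool_a.write((0, slice(0, 3)), torch.tensor([7, 8, 9]))
pool_b.apply_write_records(pool_a.get_write_records())
assert torch.equal(pool_a.data, pool_b.data)
```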
@@ -914,26 +1079,31 @@ class ModelWorkerBatch:
     # For multimodal
     image_inputs: Optional[List[ImageInputs]]
 
+    # For encoder-decoder
+    encoder_cached: Optional[List[bool]]
+    encoder_lens: Optional[torch.Tensor]
+    encoder_lens_cpu: Optional[List[int]]
+    encoder_out_cache_loc: Optional[torch.Tensor]
+
     # For LoRA
     lora_paths: Optional[List[str]]
 
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # For Qwen2-VL
+    mrope_positions_delta: List[List[int]]
+
     def copy(self):
-        return ModelWorkerBatch(
-            bid=self.bid,
-            forward_mode=self.forward_mode,
-            input_ids=self.input_ids,
-            req_pool_indices=self.req_pool_indices,
-            seq_lens=self.seq_lens,
-            out_cache_loc=self.out_cache_loc,
-            return_logprob=self.return_logprob,
-            top_logprobs_nums=self.top_logprobs_nums,
-            extend_seq_lens=self.extend_seq_lens,
-            extend_prefix_lens=self.extend_prefix_lens,
-            extend_logprob_start_lens=self.extend_logprob_start_lens,
-            image_inputs=self.image_inputs,
-            lora_paths=self.lora_paths,
-            sampling_info=self.sampling_info.copy(),
-        )
+        return dataclasses.replace(self, sampling_info=self.sampling_info.copy())
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
     |