sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_latency.py +2 -1
 - sglang/lang/chat_template.py +17 -0
 - sglang/launch_server_llavavid.py +1 -1
 - sglang/srt/configs/__init__.py +3 -0
 - sglang/srt/configs/model_config.py +27 -2
 - sglang/srt/configs/qwen2vl.py +133 -0
 - sglang/srt/constrained/fsm_cache.py +10 -3
 - sglang/srt/conversation.py +27 -0
 - sglang/srt/hf_transformers_utils.py +16 -1
 - sglang/srt/layers/attention/__init__.py +16 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
 - sglang/srt/layers/attention/flashinfer_backend.py +174 -54
 - sglang/srt/layers/attention/triton_backend.py +22 -6
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
 - sglang/srt/layers/linear.py +89 -63
 - sglang/srt/layers/logits_processor.py +5 -5
 - sglang/srt/layers/rotary_embedding.py +112 -0
 - sglang/srt/layers/sampler.py +51 -39
 - sglang/srt/lora/lora.py +3 -1
 - sglang/srt/managers/data_parallel_controller.py +1 -1
 - sglang/srt/managers/detokenizer_manager.py +4 -0
 - sglang/srt/managers/image_processor.py +186 -13
 - sglang/srt/managers/io_struct.py +10 -0
 - sglang/srt/managers/schedule_batch.py +238 -68
 - sglang/srt/managers/scheduler.py +69 -50
 - sglang/srt/managers/tokenizer_manager.py +24 -4
 - sglang/srt/managers/tp_worker.py +26 -111
 - sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
 - sglang/srt/mem_cache/memory_pool.py +56 -10
 - sglang/srt/mem_cache/radix_cache.py +4 -3
 - sglang/srt/model_executor/cuda_graph_runner.py +87 -28
 - sglang/srt/model_executor/forward_batch_info.py +83 -3
 - sglang/srt/model_executor/model_runner.py +32 -11
 - sglang/srt/models/chatglm.py +3 -3
 - sglang/srt/models/deepseek_v2.py +2 -2
 - sglang/srt/models/mllama.py +1004 -0
 - sglang/srt/models/qwen2_vl.py +724 -0
 - sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
 - sglang/srt/sampling/sampling_batch_info.py +13 -3
 - sglang/srt/sampling/sampling_params.py +5 -7
 - sglang/srt/server.py +12 -0
 - sglang/srt/server_args.py +10 -0
 - sglang/srt/utils.py +22 -0
 - sglang/test/run_eval.py +2 -0
 - sglang/test/runners.py +20 -1
 - sglang/test/srt/sampling/penaltylib/utils.py +1 -0
 - sglang/test/test_utils.py +100 -3
 - sglang/version.py +1 -1
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
 - {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/managers/scheduler.py
CHANGED

@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -51,6 +53,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -58,6 +61,7 @@ from sglang.srt.managers.schedule_policy import (
     SchedulePolicy,
 )
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -67,7 +71,6 @@ from sglang.srt.utils import (
     is_generation_model,
     is_multimodal_model,
     kill_parent_process,
-    pytorch_profile,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -91,6 +94,7 @@ class Scheduler:
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -100,6 +104,7 @@ class Scheduler:
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
         self.lora_paths = server_args.lora_paths
         self.max_loras_per_batch = server_args.max_loras_per_batch
+        self.enable_overlap = server_args.enable_overlap_schedule

         # Init inter-process communication
         context = zmq.Context(2)
@@ -143,27 +148,37 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.enable_overlap:
+            TpWorkerClass = TpModelWorkerClient
+        else:
+            TpWorkerClass = TpModelWorker
+
+        self.tp_worker = TpWorkerClass(
+            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
-            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device

         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
-        ) = self.tp_worker.get_token_and_memory_info()
+            self.device,
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Print debug info
         logger.info(
@@ -173,9 +188,8 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )

-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()

         if (
             server_args.chunked_prefill_size is not None
@@ -253,22 +267,9 @@ class Scheduler:
                 with_stack=True,
             )

-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
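
Two patterns show up in this constructor refactor: the worker class is chosen by the overlap flag (TpModelWorkerClient is a drop-in replacement for TpModelWorker), and the scheduler now reaches the worker only through accessor methods (get_worker_info, get_tp_cpu_group, get_pad_input_ids_func) instead of poking at tp_worker.model_runner. A minimal sketch of the class-swap pattern, with hypothetical stand-in classes rather than the real sglang workers:

    class BlockingWorker:
        """Stand-in for TpModelWorker: forward() returns real token ids."""
        def forward(self, batch):
            return "<logits>", [1, 2, 3]

    class OverlapWorker(BlockingWorker):
        """Stand-in for TpModelWorkerClient: forward() returns placeholders
        that a background thread resolves later."""
        def forward(self, batch):
            return "<future logits>", "<future ids>"

    def make_worker(enable_overlap: bool) -> BlockingWorker:
        # Same shape as the diff: pick a drop-in class by flag, then
        # construct it with identical arguments.
        WorkerClass = OverlapWorker if enable_overlap else BlockingWorker
        return WorkerClass()

Because both classes expose the same interface, the rest of the scheduler code is unchanged regardless of which one is instantiated.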
@@ -299,6 +300,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
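
event_loop_overlap pairs a result deque with non-blocking forward calls so the CPU can postprocess batch N while the GPU runs batch N+1. A minimal sketch of the pattern, assuming a hypothetical scheduler object with get_next_batch/run_batch/process_batch_result; the real loop also handles empty batches and idle waiting:

    from collections import deque

    def event_loop_overlap_sketch(scheduler):
        result_queue = deque()
        last_batch = None
        while True:
            batch = scheduler.get_next_batch(last_batch)
            if batch:
                # Launch this batch on the GPU first (non-blocking with the
                # overlap worker) ...
                result_queue.append((batch, scheduler.run_batch(batch)))
            if last_batch:
                # ... then do CPU-side postprocessing of the previous batch
                # while the GPU is busy.
                prev_batch, prev_result = result_queue.popleft()
                scheduler.process_batch_result(prev_batch, prev_result)
            last_batch = batch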
@@ -362,6 +364,10 @@ class Scheduler:
                     self.start_profile()
                 else:
                     self.stop_profile()
+            elif isinstance(recv_req, GetMemPoolSizeReq):
+                self.send_to_detokenizer.send_pyobj(
+                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
+                )
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
@@ -415,19 +421,20 @@ class Scheduler:
             )

         # Truncate prompts that are too long
-        if len(req.origin_input_ids)
+        if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+
         req.sampling_params.max_new_tokens = min(
             (
                 req.sampling_params.max_new_tokens
                 if req.sampling_params.max_new_tokens is not None
                 else 1 << 30
             ),
-            self.
+            self.max_req_len - len(req.origin_input_ids) - 1,
         )

         self.waiting_queue.append(req)
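
The clamp above caps a request's max_new_tokens so that prompt length plus generated tokens (plus one) never exceeds the worker-reported max_req_len; 1 << 30 stands in for "unlimited" when the user sets no cap. A worked example with made-up numbers:

    # Sketch of the clamping arithmetic, with hypothetical values:
    max_req_len = 4096                    # reported by the worker (see tp_worker.py below)
    origin_input_ids = list(range(1000))  # a 1000-token prompt
    requested = None                      # user gave no max_new_tokens

    max_new_tokens = min(
        requested if requested is not None else 1 << 30,
        max_req_len - len(origin_input_ids) - 1,
    )
    assert max_new_tokens == 3095  # prompt + new tokens + 1 stays within max_req_len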
@@ -575,6 +582,7 @@ class Scheduler:
                 else set([])
             )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -661,12 +669,13 @@ class Scheduler:
             self.req_to_token_pool,
             self.token_to_kv_pool,
             self.tree_cache,
+            self.model_config,
         )
-        new_batch.prepare_for_extend(self.model_config)
+        new_batch.prepare_for_extend()

         # Mixed-style chunked prefill
         if self.is_mixed_chunk and self.running_batch is not None:
-            self.running_batch.prepare_for_decode()
+            self.running_batch.prepare_for_decode(self.enable_overlap)
             new_batch.mix_with_running(self.running_batch)
             new_batch.decoding_reqs = self.running_batch.reqs
             self.running_batch = None
@@ -676,6 +685,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch

@@ -712,13 +722,14 @@ class Scheduler:
             return

         # Update batch tensors
-        batch.prepare_for_decode()
+        batch.prepare_for_decode(self.enable_overlap)

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.forward_batch_generation(
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
             else:
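
run_batch now always calls through self.tp_worker, and with the overlap worker the returned token ids can be placeholders that resulve_batch_result(bid) materializes later. A hedged sketch of that dispatch; the bid attribute on the worker batch is an assumption for illustration, not confirmed by this diff:

    def run_generation_step(scheduler, batch):
        model_worker_batch = batch.get_model_worker_batch()
        # With overlap scheduling, tp_worker is a TpModelWorkerClient and this
        # call can return immediately with placeholder (future) token ids;
        # otherwise it blocks until sampling finishes on the GPU.
        logits_output, next_token_ids = scheduler.tp_worker.forward_batch_generation(
            model_worker_batch
        )
        return logits_output, next_token_ids, model_worker_batch.bid  # bid assumed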
@@ -749,9 +760,12 @@ class Scheduler:
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
             logits_output, next_token_ids, bid = result
-
-
-
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                if batch.return_logprob:
                     logits_output.next_token_logprobs = (
                         logits_output.next_token_logprobs[
                             torch.arange(len(next_token_ids), device=self.device),
@@ -764,8 +778,7 @@ class Scheduler:
                 logits_output.normalized_prompt_logprobs = (
                     logits_output.normalized_prompt_logprobs.tolist()
                 )
-
-            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
+                next_token_ids = next_token_ids.tolist()

         # Check finish conditions
         logprob_pt = 0
@@ -779,7 +792,7 @@ class Scheduler:
                 req.check_finished()

                 if req.finished():
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                 elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                     self.tree_cache.cache_unfinished_req(req)

@@ -808,7 +821,7 @@ class Scheduler:
                 req.check_finished()

             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             else:
                 self.tree_cache.cache_unfinished_req(req)

@@ -818,14 +831,17 @@ class Scheduler:
         logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)

-
-
-            next_token_logprobs = logits_output.next_token_logprobs
-
-
-
-
-
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
+        else:
+            # Move next_token_ids and logprobs to cpu
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].tolist()
+            next_token_ids = next_token_ids.tolist()

         self.token_to_kv_pool.free_group_begin()
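
Both result-processing paths use the same advanced-indexing idiom, logits_output.next_token_logprobs[torch.arange(n), next_token_ids], to pull one logprob per request before moving the data to CPU. A minimal standalone illustration:

    import torch

    # Pick each sampled token's logprob out of a [batch_size, vocab]
    # tensor of next-token logprobs.
    batch_size, vocab = 4, 32000
    logprobs = torch.log_softmax(torch.randn(batch_size, vocab), dim=-1)
    next_token_ids = torch.tensor([5, 17, 2, 31999])

    picked = logprobs[torch.arange(batch_size), next_token_ids]  # shape [4]
    print(picked.tolist())  # one logprob per request, ready to ship to CPU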
@@ -845,7 +861,7 @@ class Scheduler:
             )

         if req.finished():
-            self.cache_finished_req(req)
+            self.tree_cache.cache_finished_req(req)

         if req.return_logprob:
             req.output_token_logprobs.append(
@@ -936,6 +952,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1033,6 +1050,7 @@ class Scheduler:
             )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1069,10 +1087,11 @@ class Scheduler:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
                     req.finished_reason = FINISH_ABORT()
-                    self.cache_finished_req(req)
+                    self.tree_cache.cache_finished_req(req)
                     break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
@@ -1112,7 +1131,7 @@ def run_scheduler_process(
     suppress_other_loggers()

     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
sglang/srt/managers/tokenizer_manager.py
CHANGED

@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetMemPoolSizeReq,
+    GetMemPoolSizeReqOutput,
     ProfileReq,
     RewardReqInput,
     TokenizedEmbeddingReqInput,
@@ -122,7 +124,7 @@ class TokenizerManager:

             # We want to parallelize the image pre-processing so we create an executor for it
             self.image_processor = get_image_processor(
-                self.hf_config, server_args, self.processor.image_processor
+                self.hf_config, server_args, self.processor
             )
         else:
             self.tokenizer = get_tokenizer(
@@ -191,8 +193,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params)
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data, obj
+                    obj.image_data, input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob
                 logprob_start_len = obj.logprob_start_len
                 top_logprobs_num = obj.top_logprobs_num
@@ -217,8 +221,10 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])
             if self.is_generation:
                 image_inputs = await self.image_processor.process_images_async(
-                    obj.image_data[index], obj
+                    obj.image_data[index], input_text or input_ids, obj
                 )
+                if image_inputs and "input_ids" in image_inputs:
+                    input_ids = image_inputs["input_ids"]
                 return_logprob = obj.return_logprob[index]
                 logprob_start_len = obj.logprob_start_len[index]
                 top_logprobs_num = obj.top_logprobs_num[index]
@@ -263,8 +269,10 @@ class TokenizerManager:
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
             image_inputs = await self.image_processor.process_images_async(
-                obj.image_data[0], obj
+                obj.image_data[0], input_text or input_ids, obj
             )
+            if image_inputs and "input_ids" in image_inputs:
+                input_ids = image_inputs["input_ids"]
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -525,6 +533,15 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)

+    async def get_memory_pool_size(self):
+        if self.to_create_loop:
+            self.create_handle_loop()
+
+        req = GetMemPoolSizeReq()
+        self.send_to_scheduler.send_pyobj(req)
+        self.mem_pool_size = asyncio.Future()
+        return await self.mem_pool_size
+
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -584,6 +601,9 @@ class TokenizerManager:
             if isinstance(recv_obj, UpdateWeightReqOutput):
                 self.model_update_result.set_result(recv_obj)
                 continue
+            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
+                self.mem_pool_size.set_result(recv_obj)
+                continue

             assert isinstance(
                 recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
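
get_memory_pool_size implements a small future-based RPC: send a request object over ZMQ, park on an asyncio.Future, and let the background handle_loop resolve it when the matching GetMemPoolSizeReqOutput arrives. A self-contained sketch of the same pattern, using a hypothetical in-process transport instead of the real sockets:

    import asyncio

    class MemPoolSizeClient:
        """Sketch of the future-based round trip; `send` stands in for
        send_to_scheduler.send_pyobj."""

        def __init__(self, send):
            self.send = send
            self.mem_pool_size = None

        async def get_memory_pool_size(self):
            self.send("GetMemPoolSizeReq")    # fire the request ...
            self.mem_pool_size = asyncio.Future()
            return await self.mem_pool_size   # ... and park until the loop replies

        def on_recv(self, recv_obj):
            # Called by the receive loop, mirroring handle_loop in the diff.
            self.mem_pool_size.set_result(recv_obj)

    async def demo():
        client = MemPoolSizeClient(send=lambda req: None)
        task = asyncio.ensure_future(client.get_memory_pool_size())
        await asyncio.sleep(0)                # let the request run up to its await
        client.on_recv(42)                    # pretend the scheduler answered
        print(await task)                     # -> 42

    asyncio.run(demo())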
    
        sglang/srt/managers/tp_worker.py
    CHANGED
    
    | 
         @@ -17,16 +17,12 @@ limitations under the License. 
     | 
|
| 
       17 
17 
     | 
    
         | 
| 
       18 
18 
     | 
    
         
             
            import json
         
     | 
| 
       19 
19 
     | 
    
         
             
            import logging
         
     | 
| 
       20 
     | 
    
         
            -
            import  
     | 
| 
       21 
     | 
    
         
            -
            import time
         
     | 
| 
       22 
     | 
    
         
            -
            from queue import Queue
         
     | 
| 
       23 
     | 
    
         
            -
             
     | 
| 
       24 
     | 
    
         
            -
            import torch
         
     | 
| 
      
 20 
     | 
    
         
            +
            from typing import Optional
         
     | 
| 
       25 
21 
     | 
    
         | 
| 
       26 
22 
     | 
    
         
             
            from sglang.srt.configs.model_config import ModelConfig
         
     | 
| 
       27 
23 
     | 
    
         
             
            from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
         
     | 
| 
       28 
24 
     | 
    
         
             
            from sglang.srt.managers.io_struct import UpdateWeightReqInput
         
     | 
| 
       29 
     | 
    
         
            -
            from sglang.srt.managers.schedule_batch import ModelWorkerBatch
         
     | 
| 
      
 25 
     | 
    
         
            +
            from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
         
     | 
| 
       30 
26 
     | 
    
         
             
            from sglang.srt.model_executor.forward_batch_info import ForwardBatch
         
     | 
| 
       31 
27 
     | 
    
         
             
            from sglang.srt.model_executor.model_runner import ModelRunner
         
     | 
| 
       32 
28 
     | 
    
         
             
            from sglang.srt.server_args import ServerArgs
         
     | 
| 
         @@ -40,9 +36,10 @@ class TpModelWorker: 
     | 
|
| 
       40 
36 
     | 
    
         | 
| 
       41 
37 
     | 
    
         
             
                def __init__(
         
     | 
| 
       42 
38 
     | 
    
         
             
                    self,
         
     | 
| 
      
 39 
     | 
    
         
            +
                    server_args: ServerArgs,
         
     | 
| 
       43 
40 
     | 
    
         
             
                    gpu_id: int,
         
     | 
| 
       44 
41 
     | 
    
         
             
                    tp_rank: int,
         
     | 
| 
       45 
     | 
    
         
            -
                     
     | 
| 
      
 42 
     | 
    
         
            +
                    dp_rank: Optional[int],
         
     | 
| 
       46 
43 
     | 
    
         
             
                    nccl_port: int,
         
     | 
| 
       47 
44 
     | 
    
         
             
                ):
         
     | 
| 
       48 
45 
     | 
    
         
             
                    # Parse args
         
     | 
| 
         @@ -93,10 +90,14 @@ class TpModelWorker: 
     | 
|
| 
       93 
90 
     | 
    
         
             
                        ),
         
     | 
| 
       94 
91 
     | 
    
         
             
                        self.model_runner.req_to_token_pool.size,
         
     | 
| 
       95 
92 
     | 
    
         
             
                    )
         
     | 
| 
       96 
     | 
    
         
            -
                    self. 
     | 
| 
      
 93 
     | 
    
         
            +
                    self.max_req_len = min(
         
     | 
| 
       97 
94 
     | 
    
         
             
                        self.model_config.context_len - 1,
         
     | 
| 
       98 
95 
     | 
    
         
             
                        self.max_total_num_tokens - 1,
         
     | 
| 
       99 
96 
     | 
    
         
             
                    )
         
     | 
| 
      
 97 
     | 
    
         
            +
                    self.max_req_input_len = self.max_req_len - 5
         
     | 
| 
      
 98 
     | 
    
         
            +
                    assert (
         
     | 
| 
      
 99 
     | 
    
         
            +
                        self.max_req_len > 0 and self.max_req_input_len > 0
         
     | 
| 
      
 100 
     | 
    
         
            +
                    ), "Memory pool size is too small"
         
     | 
| 
       100 
101 
     | 
    
         | 
| 
       101 
102 
     | 
    
         
             
                    # Sync random seed across TP workers
         
     | 
| 
       102 
103 
     | 
    
         
             
                    self.random_seed = broadcast_pyobj(
         
     | 
| 
@@ -106,92 +107,32 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

-
-        self.init_overlap_status()
-
-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
             self.max_running_requests,
+            self.max_req_len,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

-    def init_overlap_status(self):
-        self.
-
-
-        self.
-
-
-
-
-
-        self.future_event_map = dict()
-        self.forward_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
-        self.forward_thread = threading.Thread(
-            target=self.forward_thread_func,
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
         )
-        self.forward_thread.start()
-
-    def forward_thread_func(self):
-        with torch.cuda.stream(self.forward_stream):
-            self.forward_thread_func_()
-
-    @torch.inference_mode()
-    def forward_thread_func_(self):
-        while True:
-            tic1 = time.time()
-            model_worker_batch, future_logits_output, future_next_token_ids = (
-                self.forward_queue.get()
-            )
-
-            # Resolve future tokens in the input
-            tic2 = time.time()
-            resolved_input_ids = model_worker_batch.input_ids
-            future_mask = resolved_input_ids < 0
-            resolved_input_ids[future_mask] = self.future_token_ids_map[
-                -resolved_input_ids[future_mask]
-            ]
-
-            # Run forward
-            logits_output, next_token_ids = self.forward_batch_generation(
-                model_worker_batch
-            )
-
-            # Set future values
-            if model_worker_batch.return_logprob:
-                self.future_logits_output_dict[future_logits_output] = logits_output
-
-            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
-            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-                torch.int32
-            )
-            # logger.info("Set event")
-            self.future_token_ids_output[model_worker_batch.bid] = (
-                next_token_ids.tolist()
-            )
-            self.future_event_map[model_worker_batch.bid].set()
-
-            if False:
-                tic3 = time.time()
-                self.acc_time_with_waiting += tic3 - tic1
-                self.acc_time_without_waiting += tic3 - tic2
-                if self.forward_queue.qsize() == 0:
-                    logger.info(
-                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
-                    )
-
-    def resolve_future_token_ids(self, bid: int):
-        self.future_event_map[bid].wait()
-        ret = self.future_token_ids_output[bid]
-        del self.future_event_map[bid]
-        return ret
-
-    def resolve_future_logits_output(self, future_obj):
-        return self.future_logits_output_dict.pop(future_obj)

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
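The large removal above is the overlap-scheduling machinery: a background thread drains a queue on its own CUDA stream, and token IDs that are not yet computed circulate as negative placeholders that index into `future_token_ids_map`; per-batch `threading.Event`s wake whoever blocks in `resolve_future_token_ids`. Per the file list, this logic now lives in the new `tp_worker_overlap_thread.py` rather than being deleted outright. A self-contained, CPU-only toy of the placeholder technique, where plain lists and a fake forward function stand in for CUDA tensors and the model:

    import threading
    from queue import Queue

    # Placeholder -k indexes slot k of the map; slot 0 stays unused since -0 == 0.
    future_token_ids_map = [0] * 16
    future_token_ids_output = {}
    future_event_map = {}
    forward_queue = Queue()

    def fake_forward(input_ids):  # stands in for forward_batch_generation
        return [i + 1 for i in input_ids]

    def forward_thread_func():
        while True:
            bid, input_ids, future_slots = forward_queue.get()
            # Resolve placeholders written by earlier, already-finished batches.
            resolved = [future_token_ids_map[-i] if i < 0 else i for i in input_ids]
            next_ids = fake_forward(resolved)
            for slot, tok in zip(future_slots, next_ids):
                future_token_ids_map[-slot] = tok   # fulfill the future
            future_token_ids_output[bid] = next_ids
            future_event_map[bid].set()             # wake the waiter

    threading.Thread(target=forward_thread_func, daemon=True).start()

    # Submit batch 0, then batch 1 whose input is batch 0's still-unknown output.
    future_event_map[0] = threading.Event()
    forward_queue.put((0, [100], [-1]))
    future_event_map[1] = threading.Event()
    forward_queue.put((1, [-1], [-2]))

    future_event_map[1].wait()
    print(future_token_ids_output)  # {0: [101], 1: [102]}

The same hunk also widens `get_token_and_memory_info` into `get_worker_info` (adding the device, the global server args, and the pool sizes) and adds small accessors (`get_pad_input_ids_func`, `get_tp_cpu_group`, `get_memory_pool`) so callers can reach the worker's pools without touching `model_runner` directly.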
@@ -205,32 +146,6 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings

-    def forward_batch_generation_non_blocking(
-        self, model_worker_batch: ModelWorkerBatch
-    ):
-        # Allocate output future objects
-        future_logits_output = self.future_logits_output_ct
-        self.future_logits_output_ct += 1
-
-        bs = len(model_worker_batch.seq_lens)
-        with torch.cuda.stream(self.forward_stream):
-            future_next_token_ids = -torch.arange(
-                self.future_token_ids_ct + 1,
-                self.future_token_ids_ct + 1 + bs,
-                dtype=torch.int32,
-                device=self.device,
-            )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
-        ret = future_logits_output, future_next_token_ids
-
-        self.future_event_map[model_worker_batch.bid] = threading.Event()
-        self.forward_queue.put(
-            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
-        )
-        return ret
-
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
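In the removed `forward_batch_generation_non_blocking`, placeholder IDs were drawn from a counter that wraps modulo `future_token_ids_limit`, so the placeholder map acts as a ring buffer sized to the number of tokens that can be in flight at once. A minimal sketch of that allocation scheme (the limit here is illustrative; the real code derives it from `max_running_requests`):

    # Ring-buffer allocation of negative placeholder slots, mirroring the
    # removed forward_batch_generation_non_blocking. Slot 0 is never issued.
    future_token_ids_ct = 0
    future_token_ids_limit = 8  # illustrative

    def allocate_future_slots(batch_size):
        global future_token_ids_ct
        start = future_token_ids_ct + 1
        slots = [-(start + i) for i in range(batch_size)]
        future_token_ids_ct = (future_token_ids_ct + batch_size) % future_token_ids_limit
        return slots

    print(allocate_future_slots(3))  # [-1, -2, -3]
    print(allocate_future_slots(3))  # [-4, -5, -6]
    print(allocate_future_slots(3))  # [-7, -8, -9]; the counter then wraps to 1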