sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_serving.py +18 -1
 - sglang/lang/interpreter.py +71 -1
 - sglang/lang/ir.py +2 -0
 - sglang/srt/configs/__init__.py +4 -0
 - sglang/srt/configs/chatglm.py +78 -0
 - sglang/srt/configs/dbrx.py +279 -0
 - sglang/srt/configs/model_config.py +16 -7
 - sglang/srt/hf_transformers_utils.py +9 -14
 - sglang/srt/layers/attention/__init__.py +8 -1
 - sglang/srt/layers/attention/flashinfer_backend.py +21 -5
 - sglang/srt/layers/linear.py +89 -47
 - sglang/srt/layers/logits_processor.py +6 -6
 - sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
 - sglang/srt/layers/moe/topk.py +4 -2
 - sglang/srt/layers/parameter.py +439 -0
 - sglang/srt/layers/quantization/__init__.py +5 -2
 - sglang/srt/layers/quantization/fp8.py +107 -53
 - sglang/srt/layers/quantization/fp8_utils.py +1 -1
 - sglang/srt/layers/quantization/int8_kernel.py +54 -0
 - sglang/srt/layers/quantization/modelopt_quant.py +174 -0
 - sglang/srt/layers/quantization/w8a8_int8.py +117 -0
 - sglang/srt/layers/radix_attention.py +2 -0
 - sglang/srt/layers/vocab_parallel_embedding.py +16 -3
 - sglang/srt/managers/cache_controller.py +307 -0
 - sglang/srt/managers/configure_logging.py +43 -0
 - sglang/srt/managers/data_parallel_controller.py +2 -0
 - sglang/srt/managers/detokenizer_manager.py +0 -2
 - sglang/srt/managers/io_struct.py +29 -13
 - sglang/srt/managers/schedule_batch.py +7 -1
 - sglang/srt/managers/scheduler.py +58 -15
 - sglang/srt/managers/session_controller.py +1 -1
 - sglang/srt/managers/tokenizer_manager.py +109 -45
 - sglang/srt/mem_cache/memory_pool.py +313 -53
 - sglang/srt/metrics/collector.py +32 -35
 - sglang/srt/model_executor/cuda_graph_runner.py +14 -7
 - sglang/srt/model_executor/forward_batch_info.py +20 -15
 - sglang/srt/model_executor/model_runner.py +53 -10
 - sglang/srt/models/chatglm.py +1 -1
 - sglang/srt/models/dbrx.py +1 -1
 - sglang/srt/models/grok.py +25 -16
 - sglang/srt/models/llama.py +46 -4
 - sglang/srt/models/qwen2.py +11 -0
 - sglang/srt/models/qwen2_eagle.py +131 -0
 - sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
 - sglang/srt/sampling/sampling_batch_info.py +15 -5
 - sglang/srt/sampling/sampling_params.py +1 -1
 - sglang/srt/server.py +125 -69
 - sglang/srt/server_args.py +39 -19
 - sglang/srt/speculative/eagle_utils.py +93 -85
 - sglang/srt/speculative/eagle_worker.py +48 -33
 - sglang/srt/torch_memory_saver_adapter.py +59 -0
 - sglang/srt/utils.py +61 -5
 - sglang/test/test_programs.py +23 -1
 - sglang/test/test_utils.py +36 -7
 - sglang/version.py +1 -1
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
 - {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
 
sglang/srt/speculative/eagle_utils.py

```diff
@@ -9,13 +9,12 @@ import triton.language as tl
 from sglang.srt.layers.attention.flashinfer_backend import (
     create_flashinfer_kv_indices_triton,
 )
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
-    from
-    from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
 
 
@@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices(
 
 
 class EAGLEDraftInput(SpecInfo):
-
-    verified_id: torch.Tensor = None
-    positions: torch.Tensor = None
-    accept_length: torch.Tensor = None
-    has_finished: bool = False
-    unfinished_index: List[int] = None
-
-    def init(self, server_args: ServerArgs):
+    def __init__(self):
         self.prev_mode = ForwardMode.DECODE
         self.sample_output = None
-        self.topk: int = server_args.speculative_eagle_topk
-        self.num_verify_token: int = server_args.speculative_num_draft_tokens
-        self.spec_steps = server_args.speculative_num_steps
 
         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo):
         self.parents_list: List[torch.Tensor] = []
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0
-        self.root_token: int = None
 
-
+        self.hidden_states: torch.Tensor = None
+        self.verified_id: torch.Tensor = None
+        self.positions: torch.Tensor = None
+        self.accept_length: torch.Tensor = None
+        self.has_finished: bool = False
+        self.unfinished_index: List[int] = None
+
+    def load_server_args(self, server_args: ServerArgs):
+        self.topk: int = server_args.speculative_eagle_topk
+        self.num_verify_token: int = server_args.speculative_num_draft_tokens
+        self.spec_steps = server_args.speculative_num_steps
 
-    def prepare_for_extend(self, batch:
+    def prepare_for_extend(self, batch: ScheduleBatch):
         req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
         out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
         batch.out_cache_loc = out_cache_loc
```
         | 
| 
       303 
     | 
    
         
            -
                    bs = batch.seq_lens 
     | 
| 
      
 293 
     | 
    
         
            +
                    bs = len(batch.seq_lens)
         
     | 
| 
       304 
294 
     | 
    
         
             
                    assign_req_to_token_pool[(bs,)](
         
     | 
| 
       305 
295 
     | 
    
         
             
                        batch.req_pool_indices,
         
     | 
| 
       306 
296 
     | 
    
         
             
                        batch.req_to_token_pool.req_to_token,
         
     | 
| 
         @@ -347,6 +337,7 @@ class EAGLEDraftInput(SpecInfo): 
     | 
|
| 
       347 
337 
     | 
    
         
             
                        triton.next_power_of_2(self.spec_steps + 1),
         
     | 
| 
       348 
338 
     | 
    
         
             
                    )
         
     | 
| 
       349 
339 
     | 
    
         | 
| 
      
 340 
     | 
    
         
            +
                    batch.seq_lens_sum = sum(batch.seq_lens)
         
     | 
| 
       350 
341 
     | 
    
         
             
                    batch.input_ids = self.verified_id
         
     | 
| 
       351 
342 
     | 
    
         
             
                    self.verified_id = new_verified_id
         
     | 
| 
       352 
343 
     | 
    
         | 
| 
         @@ -419,11 +410,6 @@ class EAGLEDraftInput(SpecInfo): 
     | 
|
| 
       419 
410 
     | 
    
         
             
                    )
         
     | 
| 
       420 
411 
     | 
    
         
             
                    return bs, kv_indices, cum_kv_seq_len
         
     | 
| 
       421 
412 
     | 
    
         | 
| 
       422 
     | 
    
         
            -
                def clear(self):
         
     | 
| 
       423 
     | 
    
         
            -
                    self.iter = 0
         
     | 
| 
       424 
     | 
    
         
            -
                    self.score_list.clear()
         
     | 
| 
       425 
     | 
    
         
            -
                    self.positions = None
         
     | 
| 
       426 
     | 
    
         
            -
             
     | 
| 
       427 
413 
     | 
    
         
             
                def clear_draft_cache(self, batch):
         
     | 
| 
       428 
414 
     | 
    
         
             
                    draft_cache = torch.cat(self.cache_list, dim=0)
         
     | 
| 
       429 
415 
     | 
    
         
             
                    batch.token_to_kv_pool.free(draft_cache)
         
     | 
| 
         @@ -455,12 +441,18 @@ class EAGLEDraftInput(SpecInfo): 
     | 
|
| 
       455 
441 
     | 
    
         
             
                    return kv_indices, cum_kv_seq_len, qo_indptr, None
         
     | 
| 
       456 
442 
     | 
    
         | 
| 
       457 
443 
     | 
    
         
             
                def merge_batch(self, spec_info: EAGLEDraftInput):
         
     | 
| 
       458 
     | 
    
         
            -
             
     | 
| 
      
 444 
     | 
    
         
            +
                    if self.hidden_states is None:
         
     | 
| 
      
 445 
     | 
    
         
            +
                        self.hidden_states = spec_info.hidden_states
         
     | 
| 
      
 446 
     | 
    
         
            +
                        self.verified_id = spec_info.verified_id
         
     | 
| 
      
 447 
     | 
    
         
            +
                        self.sample_output = spec_info.sample_output
         
     | 
| 
      
 448 
     | 
    
         
            +
                        self.prev_mode = spec_info.prev_mode
         
     | 
| 
      
 449 
     | 
    
         
            +
                        return
         
     | 
| 
      
 450 
     | 
    
         
            +
                    if spec_info.hidden_states is None:
         
     | 
| 
      
 451 
     | 
    
         
            +
                        return
         
     | 
| 
       459 
452 
     | 
    
         
             
                    self.hidden_states = torch.cat(
         
     | 
| 
       460 
453 
     | 
    
         
             
                        [self.hidden_states, spec_info.hidden_states], axis=0
         
     | 
| 
       461 
454 
     | 
    
         
             
                    )
         
     | 
| 
       462 
455 
     | 
    
         
             
                    self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
         
     | 
| 
       463 
     | 
    
         
            -
                    # self.positions = torch.cat([self.positions, spec_info.positions], axis=0)
         
     | 
| 
       464 
456 
     | 
    
         
             
                    self.sample_output = torch.cat([self.sample_output, spec_info.sample_output])
         
     | 
| 
       465 
457 
     | 
    
         | 
| 
       466 
458 
     | 
    
         | 
| 
         @@ -567,11 +559,37 @@ class EagleVerifyInput(SpecInfo): 
     | 
|
| 
       567 
559 
     | 
    
         
             
                        triton.next_power_of_2(max_draft_len),
         
     | 
| 
       568 
560 
     | 
    
         
             
                    )
         
     | 
| 
       569 
561 
     | 
    
         | 
| 
       570 
     | 
    
         
            -
                    accept_index = accept_index[accept_index != -1]
         
     | 
| 
       571 
     | 
    
         
            -
                    # extract_index = extract_index[extract_index != 0]
         
     | 
| 
       572 
     | 
    
         
            -
             
     | 
| 
       573 
562 
     | 
    
         
             
                    draft_input = EAGLEDraftInput()
         
     | 
| 
      
 563 
     | 
    
         
            +
                    new_accept_index = []
         
     | 
| 
      
 564 
     | 
    
         
            +
                    unfinished_index = []
         
     | 
| 
      
 565 
     | 
    
         
            +
                    finished_extend_len = {}  # {rid:accept_length + 1}
         
     | 
| 
      
 566 
     | 
    
         
            +
                    accept_index_cpu = accept_index.tolist()
         
     | 
| 
      
 567 
     | 
    
         
            +
                    predict_cpu = predict.tolist()
         
     | 
| 
      
 568 
     | 
    
         
            +
                    # iterate every accepted token and check if req has finished after append the token
         
     | 
| 
      
 569 
     | 
    
         
            +
                    # should be checked BEFORE free kv cache slots
         
     | 
| 
      
 570 
     | 
    
         
            +
                    for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
         
     | 
| 
      
 571 
     | 
    
         
            +
                        new_accept_index_ = []
         
     | 
| 
      
 572 
     | 
    
         
            +
                        for j, idx in enumerate(accept_index_row):
         
     | 
| 
      
 573 
     | 
    
         
            +
                            if idx == -1:
         
     | 
| 
      
 574 
     | 
    
         
            +
                                break
         
     | 
| 
      
 575 
     | 
    
         
            +
                            id = predict_cpu[idx]
         
     | 
| 
      
 576 
     | 
    
         
            +
                            # if not found_finished:
         
     | 
| 
      
 577 
     | 
    
         
            +
                            req.output_ids.append(id)
         
     | 
| 
      
 578 
     | 
    
         
            +
                            finished_extend_len[req.rid] = j + 1
         
     | 
| 
      
 579 
     | 
    
         
            +
                            req.check_finished()
         
     | 
| 
      
 580 
     | 
    
         
            +
                            if req.finished():
         
     | 
| 
      
 581 
     | 
    
         
            +
                                draft_input.has_finished = True
         
     | 
| 
      
 582 
     | 
    
         
            +
                                # set all tokens after finished token to -1 and break
         
     | 
| 
      
 583 
     | 
    
         
            +
                                accept_index[i, j + 1 :] = -1
         
     | 
| 
      
 584 
     | 
    
         
            +
                                break
         
     | 
| 
      
 585 
     | 
    
         
            +
                            else:
         
     | 
| 
      
 586 
     | 
    
         
            +
                                new_accept_index_.append(idx)
         
     | 
| 
      
 587 
     | 
    
         
            +
                        if not req.finished():
         
     | 
| 
      
 588 
     | 
    
         
            +
                            new_accept_index.extend(new_accept_index_)
         
     | 
| 
      
 589 
     | 
    
         
            +
                            unfinished_index.append(i)
         
     | 
| 
      
 590 
     | 
    
         
            +
                    accept_length = (accept_index != -1).sum(dim=1) - 1
         
     | 
| 
       574 
591 
     | 
    
         | 
| 
      
 592 
     | 
    
         
            +
                    accept_index = accept_index[accept_index != -1]
         
     | 
| 
       575 
593 
     | 
    
         
             
                    accept_length_cpu = accept_length.tolist()
         
     | 
| 
       576 
594 
     | 
    
         
             
                    verified_id = predict[accept_index]
         
     | 
| 
       577 
595 
     | 
    
         
             
                    verified_id_cpu = verified_id.tolist()
         
     | 
| 
         @@ -590,29 +608,19 @@ class EagleVerifyInput(SpecInfo): 
     | 
|
| 
       590 
608 
     | 
    
         
             
                        triton.next_power_of_2(bs),
         
     | 
| 
       591 
609 
     | 
    
         
             
                    )
         
     | 
| 
       592 
610 
     | 
    
         
             
                    batch.seq_lens.add_(accept_length + 1)
         
     | 
| 
       593 
     | 
    
         
            -
                    new_accept_index = []
         
     | 
| 
       594 
     | 
    
         
            -
                    unfinished_index = []
         
     | 
| 
       595 
     | 
    
         
            -
                    finished_extend_len = {}  # {rid:accept_length + 1}
         
     | 
| 
       596 
     | 
    
         
            -
                    # retracted_reqs, new_token_ratio = batch.retract_decode()
         
     | 
| 
       597 
     | 
    
         
            -
             
     | 
| 
       598 
     | 
    
         
            -
                    low = 0
         
     | 
| 
       599 
     | 
    
         
            -
                    for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)):
         
     | 
| 
       600 
     | 
    
         
            -
                        req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1])
         
     | 
| 
       601 
     | 
    
         
            -
                        req.check_finished()
         
     | 
| 
       602 
     | 
    
         
            -
                        if req.finished():
         
     | 
| 
       603 
     | 
    
         
            -
                            draft_input.has_finished = True
         
     | 
| 
       604 
     | 
    
         
            -
                        else:
         
     | 
| 
       605 
     | 
    
         
            -
                            new_accept_index.append(accept_index[low : low + verified_len + 1])
         
     | 
| 
       606 
     | 
    
         
            -
                            unfinished_index.append(i)
         
     | 
| 
       607 
     | 
    
         
            -
                        low += verified_len + 1
         
     | 
| 
       608 
     | 
    
         
            -
                        finished_extend_len[req.rid] = verified_len + 1
         
     | 
| 
       609 
611 
     | 
    
         | 
| 
       610 
612 
     | 
    
         
             
                    if len(new_accept_index) > 0:
         
     | 
| 
       611 
     | 
    
         
            -
                        new_accept_index = torch. 
     | 
| 
      
 613 
     | 
    
         
            +
                        new_accept_index = torch.tensor(new_accept_index, device="cuda")
         
     | 
| 
       612 
614 
     | 
    
         
             
                        draft_input.verified_id = predict[new_accept_index]
         
     | 
| 
       613 
615 
     | 
    
         
             
                        draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
         
     | 
| 
       614 
616 
     | 
    
         
             
                        draft_input.accept_length = accept_length[unfinished_index]
         
     | 
| 
       615 
617 
     | 
    
         
             
                        draft_input.unfinished_index = unfinished_index
         
     | 
| 
       616 
618 
     | 
    
         | 
| 
       617 
619 
     | 
    
         
             
                    logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         
     | 
| 
       618 
     | 
    
         
            -
                    return  
     | 
| 
      
 620 
     | 
    
         
            +
                    return (
         
     | 
| 
      
 621 
     | 
    
         
            +
                        draft_input,
         
     | 
| 
      
 622 
     | 
    
         
            +
                        logits_output,
         
     | 
| 
      
 623 
     | 
    
         
            +
                        verified_id,
         
     | 
| 
      
 624 
     | 
    
         
            +
                        finished_extend_len,
         
     | 
| 
      
 625 
     | 
    
         
            +
                        accept_length_cpu,
         
     | 
| 
      
 626 
     | 
    
         
            +
                    )
         
     | 
| 
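The new `selected_input_index` arithmetic in `prepare_for_decode` maps each selected tree candidate back to the row of the flattened `(b * topk, hidden)` hidden-state tensor that produced it. A worked example with illustrative values (not taken from sglang), runnable on CPU:

```python
import torch

# topk_cs_index holds, per request, flat indices into that request's
# (topk x topk) candidate grid; dividing by topk recovers the parent row,
# and the arange(...).repeat_interleave(topk) term shifts that row into the
# flattened (b * topk) hidden_states layout.
b, topk = 2, 2
topk_cs_index = torch.tensor([[0, 3], [2, 1]])  # example selections
offsets = torch.arange(0, b * topk, step=topk).repeat_interleave(topk)  # [0, 0, 2, 2]
selected_input_index = topk_cs_index.flatten() // topk + offsets
print(selected_input_index)  # tensor([0, 1, 3, 2]) -> rows of hidden_states to keep
```

The per-request offset term is what makes the gather correct across a batch: without it, every request's selection would resolve to rows 0..topk-1 of the flattened tensor.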
         @@ -40,6 +40,7 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       40 
40 
     | 
    
         
             
                    )
         
     | 
| 
       41 
41 
     | 
    
         
             
                    self.target_worker = target_worker
         
     | 
| 
       42 
42 
     | 
    
         
             
                    self.server_args = server_args
         
     | 
| 
      
 43 
     | 
    
         
            +
                    self.finish_extend_len = []
         
     | 
| 
       43 
44 
     | 
    
         | 
| 
       44 
45 
     | 
    
         
             
                    # Share the embedding and lm_head
         
     | 
| 
       45 
46 
     | 
    
         
             
                    embed, head = self.target_worker.model_runner.model.get_embed_and_head()
         
     | 
| 
         @@ -51,63 +52,72 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       51 
52 
     | 
    
         
             
                    batch.spec_info.prepare_for_decode(batch)
         
     | 
| 
       52 
53 
     | 
    
         
             
                    model_worker_batch = batch.get_model_worker_batch()
         
     | 
| 
       53 
54 
     | 
    
         
             
                    forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         
     | 
| 
       54 
     | 
    
         
            -
                    forward_batch. 
     | 
| 
      
 55 
     | 
    
         
            +
                    forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         
     | 
| 
       55 
56 
     | 
    
         
             
                    logits_output = self.model_runner.forward(forward_batch)
         
     | 
| 
       56 
57 
     | 
    
         
             
                    self.capture_for_decode(logits_output, forward_batch)
         
     | 
| 
       57 
58 
     | 
    
         | 
| 
       58 
59 
     | 
    
         
             
                def forward_draft_extend(self, batch: ScheduleBatch):
         
     | 
| 
       59 
     | 
    
         
            -
                    self. 
     | 
| 
      
 60 
     | 
    
         
            +
                    self._set_mem_pool(batch, self.model_runner)
         
     | 
| 
       60 
61 
     | 
    
         
             
                    batch.spec_info.prepare_for_extend(batch)
         
     | 
| 
       61 
62 
     | 
    
         
             
                    model_worker_batch = batch.get_model_worker_batch()
         
     | 
| 
       62 
63 
     | 
    
         
             
                    forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         
     | 
| 
       63 
     | 
    
         
            -
                    forward_batch. 
     | 
| 
      
 64 
     | 
    
         
            +
                    forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         
     | 
| 
       64 
65 
     | 
    
         
             
                    logits_output = self.model_runner.forward(forward_batch)
         
     | 
| 
       65 
66 
     | 
    
         
             
                    self.capture_for_decode(logits_output, forward_batch)
         
     | 
| 
       66 
     | 
    
         
            -
                    self. 
     | 
| 
      
 67 
     | 
    
         
            +
                    self._set_mem_pool(batch, self.target_worker.model_runner)
         
     | 
| 
       67 
68 
     | 
    
         | 
| 
       68 
69 
     | 
    
         
             
                def forward_batch_speculative_generation(self, batch: ScheduleBatch):
         
     | 
| 
       69 
70 
     | 
    
         
             
                    if batch.forward_mode.is_decode():
         
     | 
| 
       70 
     | 
    
         
            -
                         
     | 
| 
       71 
     | 
    
         
            -
                        self. 
     | 
| 
      
 71 
     | 
    
         
            +
                        # Draft
         
     | 
| 
      
 72 
     | 
    
         
            +
                        self._set_mem_pool(batch, self.model_runner)
         
     | 
| 
       72 
73 
     | 
    
         
             
                        for i in range(self.server_args.speculative_num_steps):
         
     | 
| 
       73 
74 
     | 
    
         
             
                            self.forward_draft_decode(batch)
         
     | 
| 
       74 
75 
     | 
    
         
             
                        batch.spec_info.clear_draft_cache(batch)
         
     | 
| 
       75 
     | 
    
         
            -
                        self. 
     | 
| 
      
 76 
     | 
    
         
            +
                        self._set_mem_pool(batch, self.target_worker.model_runner)
         
     | 
| 
      
 77 
     | 
    
         
            +
             
     | 
| 
      
 78 
     | 
    
         
            +
                        # Verify
         
     | 
| 
       76 
79 
     | 
    
         
             
                        (
         
     | 
| 
       77 
80 
     | 
    
         
             
                            next_draft_input,
         
     | 
| 
       78 
81 
     | 
    
         
             
                            logits_output,
         
     | 
| 
       79 
82 
     | 
    
         
             
                            verified_id,
         
     | 
| 
       80 
83 
     | 
    
         
             
                            self.finish_extend_len,
         
     | 
| 
      
 84 
     | 
    
         
            +
                            accept_length_cpu,
         
     | 
| 
       81 
85 
     | 
    
         
             
                            model_worker_batch,
         
     | 
| 
       82 
86 
     | 
    
         
             
                        ) = self.verify(batch)
         
     | 
| 
       83 
     | 
    
         
            -
                        next_draft_input. 
     | 
| 
      
 87 
     | 
    
         
            +
                        next_draft_input.load_server_args(self.server_args)
         
     | 
| 
       84 
88 
     | 
    
         
             
                        batch.spec_info = next_draft_input
         
     | 
| 
       85 
89 
     | 
    
         
             
                        # if it is None, means all requsets are finished
         
     | 
| 
       86 
90 
     | 
    
         
             
                        if batch.spec_info.verified_id is not None:
         
     | 
| 
       87 
     | 
    
         
            -
                            self. 
     | 
| 
       88 
     | 
    
         
            -
                         
     | 
| 
       89 
     | 
    
         
            -
             
     | 
| 
      
 91 
     | 
    
         
            +
                            self.forward_draft_extend_after_decode(batch)
         
     | 
| 
      
 92 
     | 
    
         
            +
                        return (
         
     | 
| 
      
 93 
     | 
    
         
            +
                            logits_output,
         
     | 
| 
      
 94 
     | 
    
         
            +
                            verified_id,
         
     | 
| 
      
 95 
     | 
    
         
            +
                            model_worker_batch,
         
     | 
| 
      
 96 
     | 
    
         
            +
                            sum(accept_length_cpu),
         
     | 
| 
      
 97 
     | 
    
         
            +
                        )
         
     | 
| 
       90 
98 
     | 
    
         | 
| 
       91 
99 
     | 
    
         
             
                    else:
         
     | 
| 
       92 
     | 
    
         
            -
                         
     | 
| 
       93 
     | 
    
         
            -
                         
     | 
| 
      
 100 
     | 
    
         
            +
                        # Forward with the target model and get hidden states.
         
     | 
| 
      
 101 
     | 
    
         
            +
                        # We need the full hidden states to prefill the KV cache of the draft model.
         
     | 
| 
       94 
102 
     | 
    
         
             
                        model_worker_batch = batch.get_model_worker_batch()
         
     | 
| 
       95 
     | 
    
         
            -
                        model_worker_batch. 
     | 
| 
       96 
     | 
    
         
            -
                        spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         
     | 
| 
      
 103 
     | 
    
         
            +
                        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
         
     | 
| 
       97 
104 
     | 
    
         
             
                        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
         
     | 
| 
       98 
105 
     | 
    
         
             
                            model_worker_batch
         
     | 
| 
       99 
106 
     | 
    
         
             
                        )
         
     | 
| 
       100 
     | 
    
         
            -
             
     | 
| 
       101 
     | 
    
         
            -
                         
     | 
| 
      
 107 
     | 
    
         
            +
             
     | 
| 
      
 108 
     | 
    
         
            +
                        # Forward with the draft model.
         
     | 
| 
      
 109 
     | 
    
         
            +
                        spec_info = EAGLEDraftInput()
         
     | 
| 
      
 110 
     | 
    
         
            +
                        spec_info.load_server_args(self.server_args)
         
     | 
| 
      
 111 
     | 
    
         
            +
                        spec_info.hidden_states = logits_output.hidden_states
         
     | 
| 
      
 112 
     | 
    
         
            +
                        spec_info.verified_id = next_token_ids
         
     | 
| 
       102 
113 
     | 
    
         
             
                        batch.spec_info = spec_info
         
     | 
| 
       103 
114 
     | 
    
         
             
                        self.forward_draft_extend(batch)
         
     | 
| 
       104 
     | 
    
         
            -
                         
     | 
| 
       105 
     | 
    
         
            -
                        return logits_output, next_token_ids, model_worker_batch, spec_info
         
     | 
| 
      
 115 
     | 
    
         
            +
                        return logits_output, next_token_ids, model_worker_batch, 0
         
     | 
| 
       106 
116 
     | 
    
         | 
| 
       107 
117 
     | 
    
         
             
                def verify(self, batch: ScheduleBatch):
         
     | 
| 
       108 
118 
     | 
    
         
             
                    verify_input = batch.spec_info.prepare_for_verify(batch)
         
     | 
| 
       109 
     | 
    
         
            -
                    batch.forward_mode = ForwardMode.TARGET_VERIFY
         
     | 
| 
       110 
119 
     | 
    
         
             
                    verify_input.prepare_for_verify(batch)
         
     | 
| 
      
 120 
     | 
    
         
            +
                    batch.forward_mode = ForwardMode.TARGET_VERIFY
         
     | 
| 
       111 
121 
     | 
    
         
             
                    batch.spec_info = verify_input
         
     | 
| 
       112 
122 
     | 
    
         
             
                    batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL
         
     | 
| 
       113 
123 
     | 
    
         
             
                    model_worker_batch = batch.get_model_worker_batch()
         
     | 
| 
         @@ -119,44 +129,49 @@ class EAGLEWorker(TpModelWorker): 
     | 
|
| 
       119 
129 
     | 
    
         
             
                    batch.forward_mode = ForwardMode.DECODE
         
     | 
| 
       120 
130 
     | 
    
         
             
                    return res + (model_worker_batch,)
         
     | 
| 
       121 
131 
     | 
    
         | 
| 
       122 
     | 
    
         
            -
                def  
     | 
| 
      
 132 
     | 
    
         
            +
                def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
         
     | 
| 
       123 
133 
     | 
    
         
             
                    batch.token_to_kv_pool = runner.token_to_kv_pool
         
     | 
| 
       124 
134 
     | 
    
         
             
                    batch.req_to_token_pool = runner.req_to_token_pool
         
     | 
| 
       125 
135 
     | 
    
         | 
| 
       126 
     | 
    
         
            -
                def  
     | 
| 
       127 
     | 
    
         
            -
                    self. 
     | 
| 
      
 136 
     | 
    
         
            +
                def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         
     | 
| 
      
 137 
     | 
    
         
            +
                    self._set_mem_pool(batch, self.model_runner)
         
     | 
| 
       128 
138 
     | 
    
         
             
                    batch.forward_mode = ForwardMode.DRAFT_EXTEND
         
     | 
| 
       129 
139 
     | 
    
         
             
                    if batch.spec_info.has_finished:
         
     | 
| 
       130 
140 
     | 
    
         
             
                        index = batch.spec_info.unfinished_index
         
     | 
| 
       131 
141 
     | 
    
         
             
                        seq_lens = batch.seq_lens
         
     | 
| 
       132 
142 
     | 
    
         
             
                        batch.seq_lens = batch.seq_lens[index]
         
     | 
| 
      
 143 
     | 
    
         
            +
             
     | 
| 
       133 
144 
     | 
    
         
             
                    batch.spec_info.prepare_extend_after_decode(batch)
         
     | 
| 
       134 
145 
     | 
    
         
             
                    model_worker_batch = batch.get_model_worker_batch()
         
     | 
| 
       135 
146 
     | 
    
         
             
                    forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         
     | 
| 
       136 
     | 
    
         
            -
                    forward_batch. 
     | 
| 
      
 147 
     | 
    
         
            +
                    forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         
     | 
| 
       137 
148 
     | 
    
         
             
                    logits_output = self.model_runner.forward(forward_batch)
         
     | 
| 
      
 149 
     | 
    
         
            +
             
     | 
| 
       138 
150 
     | 
    
         
             
             batch.spec_info.hidden_states = logits_output.hidden_states
             self.capture_for_decode(logits_output, forward_batch)
             batch.forward_mode = ForwardMode.DECODE
             if batch.spec_info.has_finished:
                 batch.seq_lens = seq_lens
-            self.
+            self._set_mem_pool(batch, self.target_worker.model_runner)

-    def capture_for_decode(
-
-
+    def capture_for_decode(
+        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+    ):
         sample_output = torch.softmax(
-
-        )  # TODO: Support more sampling
-        forward_batch.spec_info
-
-
+            logits_output.next_token_logits, dim=-1
+        )  # TODO(kavioyu): Support more sampling methods
+        spec_info = forward_batch.spec_info
+        spec_info.sample_output = sample_output
+        spec_info.hidden_states = logits_output.hidden_states
+        spec_info.prev_mode = forward_batch.forward_mode

     # Don't support prefix share now.
     def finish_request(self, reqs: Union[Req, List[Req]]):
         if not isinstance(reqs, List):
             reqs = [reqs]
         for req in reqs:
+            if req.rid not in self.finish_extend_len:
+                continue
             req_len = (
                 len(req.origin_input_ids)
                 + len(req.output_ids)
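The `capture_for_decode` rewrite above keeps the whole softmax distribution on `spec_info` instead of consuming it inline, and the TODO notes that only this one sampling scheme is wired up so far. As a minimal torch-only sketch of what consuming such a distribution looks like (the batch and vocab sizes are invented, and the variable names are ours, not sglang's):

```python
import torch

# Hypothetical (batch, vocab) logits standing in for logits_output.next_token_logits.
next_token_logits = torch.randn(4, 32000)
probs = torch.softmax(next_token_logits, dim=-1)  # same call as capture_for_decode

greedy_ids = torch.argmax(probs, dim=-1)  # greedy decoding: one id per sequence
sampled_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)  # stochastic draw
print(greedy_ids.shape, sampled_ids.shape)  # torch.Size([4]) torch.Size([4])
```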
sglang/srt/torch_memory_saver_adapter.py
ADDED

@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
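For reference, a usage sketch of the adapter defined above; this is our illustration rather than package code. With `enable=False` the no-op variant is returned, so the snippet runs even when the optional `torch_memory_saver` dependency is absent:

```python
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=False)  # no-op variant

with adapter.region():
    # With the real adapter, allocations made here are tracked so they can be
    # released later; under the no-op adapter this is an ordinary block.
    pass

adapter.pause()   # real adapter: presumably releases the tracked GPU memory
adapter.resume()  # real adapter: presumably restores it; both are no-ops here
```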
    
        sglang/srt/utils.py
    CHANGED
    
@@ -97,6 +97,10 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda
 
 
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda
+
+
 def is_ipv6(address):
     try:
         ipaddress.IPv6Address(address)
@@ -335,6 +339,8 @@ def is_port_available(port):
             return True
         except socket.error:
             return False
+        except OverflowError:
+            return False
 
 
 def decode_video_base64(video_base64):
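The new `except OverflowError` branch covers a case `socket.error` does not: CPython raises `OverflowError` when the requested port falls outside 0-65535, which previously would have propagated out of `is_port_available` instead of returning `False`. A small repro sketch (ours, not from the package):

```python
import socket

try:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 70000))  # out-of-range port
except OverflowError as e:
    print("rejected:", e)  # e.g. "bind(): port must be 0-65535."
```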
@@ -709,13 +715,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
@@ -724,19 +731,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)
 
-            dist.broadcast(tensor_size, src=
-            dist.broadcast(tensor_data, src=
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
 
         if size == 0:
             return []
 
         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=
+        dist.broadcast(tensor_data, src=src, group=dist_group)
 
         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
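The two hunks above thread a configurable `src` rank into every `dist.broadcast` call; note the sender branch is still selected by `rank == 0`, so `rank` is the caller's position within the group while `src` is the sender's global rank passed through to `torch.distributed`. A hedged usage sketch; the group membership below is invented, and it assumes `dist.init_process_group(...)` ran earlier:

```python
import torch.distributed as dist
from sglang.srt.utils import broadcast_pyobj

group = dist.new_group(ranks=[2, 3, 4, 5])   # hypothetical subgroup
rel_rank = dist.get_rank() - 2               # 0 on the sending process
data = [{"req_id": 7}] if rel_rank == 0 else []
# rank=0 selects the serialize-and-send path; src=2 is the sender's global rank.
data = broadcast_pyobj(data, rank=rel_rank, dist_group=group, src=2)
```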
@@ -1337,6 +1344,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list
 
 
+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
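A quick sketch of `permute_weight`'s shape contract, with invented sizes: for fp16/bf16 the intermediate view needs `n` divisible by 16 and `k` by 32 (by 64 for fp8/int8), and the result keeps the input shape while shuffling element layout:

```python
import torch
from sglang.srt.utils import permute_weight

w = torch.randn(1, 32, 64, dtype=torch.bfloat16)  # (b, n, k) with n % 16 == 0, k % 32 == 0
w_perm = permute_weight(w)
assert w_perm.shape == w.shape     # shape preserved
assert not torch.equal(w_perm, w)  # elements reordered (almost surely, for random data)
```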
@@ -1348,3 +1374,33 @@ class MultiprocessingSerializer:
     @staticmethod
     def deserialize(data):
         return ForkingPickler.loads(data)
+
+
+def debug_timing(func):
+    # todo: replace with a more organized instrumentation
+    def wrapper(*args, **kwargs):
+        if logger.isEnabledFor(logging.DEBUG):
+            tic = torch.cuda.Event(enable_timing=True)
+            toc = torch.cuda.Event(enable_timing=True)
+            tic.record()
+            result = func(*args, **kwargs)
+            toc.record()
+            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            elapsed = tic.elapsed_time(toc)
+            indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
+            num_tokens = len(indices) if indices is not None else 0
+            throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0
+            logger.debug(
+                f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s"
+            )
+            return result
+        else:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
        sglang/test/test_programs.py
    CHANGED
    
@@ -509,13 +509,35 @@ def test_hellaswag_select():
         temperature=0,
         num_threads=64,
         progress_bar=True,
+        generator_style=False,
     )
-    preds = [
+    preds = []
+    for i, ret in enumerate(rets):
+        preds.append(choices[i].index(ret["answer"]))
     latency = time.time() - tic
 
     # Compute accuracy
     accuracy = np.mean(np.array(preds) == np.array(labels))
 
+    # Test generator style of run_batch
+    tic = time.time()
+    rets = few_shot_hellaswag.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=64,
+        progress_bar=True,
+        generator_style=True,
+    )
+    preds_gen = []
+    for i, ret in enumerate(rets):
+        preds_gen.append(choices[i].index(ret["answer"]))
+    latency_gen = time.time() - tic
+
+    # Compute accuracy
+    accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
+    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(latency_gen - latency) < 1
+
     return accuracy, latency
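To make the new assertions concrete: the test now runs the same batch twice and checks that generator-style consumption matches the list-style results within 1% accuracy and 1 s of latency. A short sketch of the generator path, reusing the test's `few_shot_hellaswag` and `arguments` names:

```python
rets = few_shot_hellaswag.run_batch(arguments, generator_style=True)
for i, ret in enumerate(rets):  # iteration drives the generator
    print(i, ret["answer"])
```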