sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
sglang/srt/mem_cache/memory_pool.py

@@ -1213,6 +1213,65 @@ def set_mla_kv_buffer_triton(
     )


+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -1363,6 +1422,29 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope,
         )

+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
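
To make the data movement concrete, here is a small pure-PyTorch reference (not part of the package) of what the new Triton kernel computes: for every slot index in `loc`, the packed MLA entry of width `nope_dim + rope_dim` is split into a latent ("nope") part and a rotary ("rope") part, optionally cast to a destination dtype. The 512/64 widths come from the comments in the diff; the per-token head dimension of 1 used by `get_mla_kv_buffer` is omitted here.

```python
import torch

def get_mla_kv_buffer_reference(kv_buffer, loc, nope_dim=512, rope_dim=64, dst_dtype=None):
    # Equivalent gather expressed with tensor indexing; the package does this
    # with one Triton program per entry of `loc`.
    dst_dtype = dst_dtype or kv_buffer.dtype
    rows = kv_buffer[loc]                                               # (n_loc, nope_dim + rope_dim)
    cache_k_nope = rows[:, :nope_dim].to(dst_dtype)                     # latent ("nope") part
    cache_k_rope = rows[:, nope_dim:nope_dim + rope_dim].to(dst_dtype)  # rotary ("rope") part
    return cache_k_nope, cache_k_rope

kv_buffer = torch.randn(1024, 512 + 64, dtype=torch.bfloat16)
loc = torch.tensor([3, 17, 42])
k_nope, k_rope = get_mla_kv_buffer_reference(kv_buffer, loc, dst_dtype=torch.float16)
print(k_nope.shape, k_rope.shape)  # torch.Size([3, 512]) torch.Size([3, 64])
```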

sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py

@@ -3,8 +3,9 @@ import atexit
 import json
 import logging
 import threading
+from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional,
+from typing import Dict, List, Optional, Tuple

 import orjson
 import requests
@@ -136,7 +137,7 @@ class GlobalMetadataState:
                     num_pages = data["num_pages"]
                     rank_meta = RankMetadata(num_pages)
                     rank_meta.free_pages = data["free_pages"]
-                    rank_meta.key_to_index =
+                    rank_meta.key_to_index = OrderedDict(data["key_to_index"])
                     self.ranks[rank_id] = rank_meta
                 logging.info(
                     f"Successfully loaded metadata for {len(self.ranks)} ranks."
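
The functional change here is that the persisted key-to-index mapping is rebuilt as an `OrderedDict` after the JSON load, which always yields a plain dict. A minimal illustration of why the wrapper matters, assuming the metadata server relies on OrderedDict-only ordering operations such as `move_to_end` (that usage is an assumption, not shown in this diff):

```python
import json
from collections import OrderedDict

# Illustration only: JSON round-trips drop the OrderedDict type, so the loader
# has to rebuild it; OrderedDict(...) accepts either a mapping or key/value pairs.
persisted = json.dumps({"key_to_index": {"page-a": 0, "page-b": 1}})
data = json.loads(persisted)                      # plain dict after loading
key_to_index = OrderedDict(data["key_to_index"])  # restore the ordered type
key_to_index.move_to_end("page-a")                # OrderedDict-only recency update (assumed usage)
print(list(key_to_index))                         # ['page-b', 'page-a']
```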

sglang/srt/model_executor/forward_batch_info.py

@@ -39,6 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -250,6 +251,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None

     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -572,9 +575,15 @@ class ForwardBatch:
                         device=model_runner.device,
                     )
                 else:
-
-
-
+                    if mm_input.mrope_position_delta.device.type != model_runner.device:
+                        # transfer mrope_position_delta to device when the first running,
+                        # avoiding successvie host-to-device data transfer
+                        mm_input.mrope_position_delta = (
+                            mm_input.mrope_position_delta.to(
+                                model_runner.device, non_blocking=True
+                            )
+                        )
+                    mrope_position_deltas = mm_input.mrope_position_delta.flatten()
                     mrope_positions_list[batch_idx] = (
                         (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
                         .unsqueeze(0)
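
The new else-branch moves `mrope_position_delta` to the target device only the first time it is seen, so later steps reuse the cached device tensor instead of repeating the host-to-device copy. A generic sketch of that "copy once, reuse afterwards" pattern (illustrative names, not package code):

```python
import torch

class CachedDelta:
    """Holds a tensor that starts on the host and is moved to the device lazily."""

    def __init__(self, delta_cpu: torch.Tensor):
        self.delta = delta_cpu  # starts on the host

    def on_device(self, device: str) -> torch.Tensor:
        if self.delta.device.type != torch.device(device).type:
            # one-time, asynchronous H2D copy; subsequent calls skip this branch
            self.delta = self.delta.to(device, non_blocking=True)
        return self.delta

holder = CachedDelta(torch.tensor([[3], [7]]))
print(holder.on_device("cpu").flatten())  # tensor([3, 7]); with a GPU, pass "cuda"
```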

sglang/srt/model_executor/forward_batch_info.py (continued)

@@ -863,6 +872,10 @@ class ForwardBatch:
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"

+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
@@ -917,6 +930,34 @@ class ForwardBatch:
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None

+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+

 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
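
`fetch_mha_one_shot_kv_indices` lazily builds, and memoizes on the `ForwardBatch`, a CSR-style view of the whole batch's KV cache slots: `kv_indptr` holds cumulative sequence lengths and `kv_indices` concatenates each request's slots from the request-to-token table (produced above with the `create_flashinfer_kv_indices_triton` kernel). A pure-PyTorch sketch of the same layout, with made-up shapes:

```python
import torch

# Sketch of the CSR-style layout; req_to_token maps
# (request slot, position) -> KV cache slot. Shapes are illustrative.
def build_kv_indices(req_to_token, req_pool_indices, seq_lens):
    kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    kv_indices = torch.cat(
        [req_to_token[r, :l] for r, l in zip(req_pool_indices.tolist(), seq_lens.tolist())]
    ).to(torch.int32)
    return kv_indptr, kv_indices

req_to_token = torch.arange(4 * 8).reshape(4, 8)  # 4 request slots, 8 positions each
req_pool_indices = torch.tensor([2, 0])           # this batch uses request slots 2 and 0
seq_lens = torch.tensor([3, 5])                   # per-request KV lengths
indptr, indices = build_kv_indices(req_to_token, req_pool_indices, seq_lens)
print(indptr.tolist())   # [0, 3, 8]
print(indices.tolist())  # [16, 17, 18, 0, 1, 2, 3, 4]
```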

sglang/srt/model_executor/model_runner.py

@@ -131,16 +131,10 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
-    monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
     xpu_has_xmx_support,
@@ -503,121 +497,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend
-            == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -643,37 +522,12 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-
         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
                     self.model_config.hf_config, "quantization_config", None
                 )
-            ) is not None:
+            ) is not None and "weight_block_size" in quantization_config:
                 weight_block_size_n = quantization_config["weight_block_size"][0]

                 if self.tp_size % self.moe_ep_size != 0:
@@ -858,8 +712,6 @@ class ModelRunner:
             self.model_config = adjust_config_with_unaligned_cpu_tp(
                 self.model_config, self.load_config, self.tp_size
             )
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()

         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
             if self.tp_rank == 0:
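
Most of this file's diff removes the attention-backend auto-selection logic and the GGUF monkey patch, presumably relocated as part of the large `server_args.py` change listed above. The one functional edit kept here is the tightened `qwen3_vl_moe` guard, which now also checks for `weight_block_size` before indexing the quantization config. A minimal illustration of the guard, with plain dicts standing in for the HF config objects:

```python
# Illustration only: index weight_block_size only when the key is present,
# so quantized checkpoints without block-wise scales no longer raise KeyError.
def pick_weight_block_size_n(hf_config: dict):
    if (
        quantization_config := hf_config.get("quantization_config")
    ) is not None and "weight_block_size" in quantization_config:
        return quantization_config["weight_block_size"][0]
    return None  # unquantized, or quantized without block-wise scales

print(pick_weight_block_size_n({"quantization_config": {"weight_block_size": [128, 128]}}))  # 128
print(pick_weight_block_size_n({"quantization_config": {"bits": 4}}))                        # None
print(pick_weight_block_size_n({}))                                                          # None
```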

sglang/srt/model_executor/piecewise_cuda_graph_runner.py

@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -250,6 +249,9 @@ class PiecewiseCudaGraphRunner:
                 lora_ids=None,
             )

+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
         with set_forward_context(forward_batch, self.attention_layers):
             _ = self.model_runner.model.forward(
                 forward_batch.input_ids,
@@ -262,9 +264,14 @@ class PiecewiseCudaGraphRunner:

     def can_run(self, forward_batch: ForwardBatch):
         num_tokens = len(forward_batch.input_ids)
-        # TODO(yuwei): support return logprob
+        # TODO(yuwei): support return input_ids' logprob
         if forward_batch.return_logprob:
-
+            for start_len, seq_len in zip(
+                forward_batch.extend_logprob_start_lens_cpu,
+                forward_batch.extend_seq_lens_cpu,
+            ):
+                if start_len is not None and start_len < seq_len:
+                    return False
         if num_tokens <= self.max_num_tokens:
             return True
         return False
@@ -273,10 +280,10 @@ class PiecewiseCudaGraphRunner:
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(
-            self.model_runner.
-
-
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -304,9 +311,10 @@ class PiecewiseCudaGraphRunner:

                 # Save gemlite cache after each capture
                 save_gemlite_cache()
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

     def capture_one_batch_size(self, num_tokens: int):
-        stream = self.stream
         bs = 1

         # Graph inputs
@@ -370,9 +378,6 @@ class PiecewiseCudaGraphRunner:
         if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

-        # # Attention backend
-        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
-
         # Run and capture
         def run_once():
             # Clean intermediate result cache for DP attention
@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner:
             out_cache_loc=out_cache_loc,
             seq_lens_sum=forward_batch.seq_lens_sum,
             encoder_lens=forward_batch.encoder_lens,
-            return_logprob=
+            return_logprob=False,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_prefix_lens=forward_batch.extend_prefix_lens,
             extend_start_loc=forward_batch.extend_start_loc,
@@ -474,6 +479,9 @@ class PiecewiseCudaGraphRunner:
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -499,6 +507,8 @@ class PiecewiseCudaGraphRunner:
                 raise NotImplementedError(
                     "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
                 )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

     def get_spec_info(self, num_tokens: int):
         spec_info = None
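
Both the capture and replay paths now wrap their work in the same save/disable/restore handling of the tensor-parallel group's custom all-reduce (`ca_comm`). A hypothetical context manager (not in the package) that captures the repeated pattern, shown only to clarify what those added lines do:

```python
from contextlib import contextmanager

@contextmanager
def custom_allreduce_disabled(tp_group):
    """Temporarily set ca_comm.disabled = True, restoring the old value on exit."""
    ca_comm = getattr(tp_group, "ca_comm", None)
    if ca_comm is None:
        yield
        return
    old_disabled = ca_comm.disabled
    ca_comm.disabled = True          # temporarily turn off custom all-reduce
    try:
        yield
    finally:
        ca_comm.disabled = old_disabled  # restore whatever the caller had configured

# Usage sketch:
# with custom_allreduce_disabled(self.model_runner.tp_group):
#     ...capture or replay the piecewise CUDA graph...
```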