sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py
@@ -1213,6 +1213,65 @@ def set_mla_kv_buffer_triton(
     )


+@triton.jit
+def get_mla_kv_buffer_kernel(
+    kv_buffer_ptr,
+    cache_k_nope_ptr,
+    cache_k_rope_ptr,
+    loc_ptr,
+    buffer_stride: tl.constexpr,
+    nope_stride: tl.constexpr,
+    rope_stride: tl.constexpr,
+    nope_dim: tl.constexpr,
+    rope_dim: tl.constexpr,
+):
+    pid_loc = tl.program_id(0)
+    loc = tl.load(loc_ptr + pid_loc)
+    loc_src_ptr = kv_buffer_ptr + loc * buffer_stride
+
+    nope_offs = tl.arange(0, nope_dim)
+    nope_src_ptr = loc_src_ptr + nope_offs
+    nope_src = tl.load(nope_src_ptr)
+
+    tl.store(
+        cache_k_nope_ptr + pid_loc * nope_stride + nope_offs,
+        nope_src,
+    )
+
+    rope_offs = tl.arange(0, rope_dim)
+    rope_src_ptr = loc_src_ptr + nope_dim + rope_offs
+    rope_src = tl.load(rope_src_ptr)
+    tl.store(
+        cache_k_rope_ptr + pid_loc * rope_stride + rope_offs,
+        rope_src,
+    )
+
+
+def get_mla_kv_buffer_triton(
+    kv_buffer: torch.Tensor,
+    loc: torch.Tensor,
+    cache_k_nope: torch.Tensor,
+    cache_k_rope: torch.Tensor,
+):
+    # The source data type will be implicitly converted to the target data type.
+    nope_dim = cache_k_nope.shape[-1]  # 512
+    rope_dim = cache_k_rope.shape[-1]  # 64
+    n_loc = loc.numel()
+    grid = (n_loc,)
+
+    get_mla_kv_buffer_kernel[grid](
+        kv_buffer,
+        cache_k_nope,
+        cache_k_rope,
+        loc,
+        kv_buffer.stride(0),
+        cache_k_nope.stride(0),
+        cache_k_rope.stride(0),
+        nope_dim,
+        rope_dim,
+    )
+
+
 class MLATokenToKVPool(KVCache):
     def __init__(
         self,
@@ -1363,6 +1422,29 @@ class MLATokenToKVPool(KVCache):
             cache_k_rope,
         )

+    def get_mla_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        dst_dtype: Optional[torch.dtype] = None,
+    ):
+        # get k nope and k rope from the kv buffer, and optionally cast them to dst_dtype.
+        layer_id = layer.layer_id
+        kv_buffer = self.get_key_buffer(layer_id)
+        dst_dtype = dst_dtype or self.dtype
+        cache_k_nope = torch.empty(
+            (loc.shape[0], 1, self.kv_lora_rank),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        cache_k_rope = torch.empty(
+            (loc.shape[0], 1, self.qk_rope_head_dim),
+            dtype=dst_dtype,
+            device=kv_buffer.device,
+        )
+        get_mla_kv_buffer_triton(kv_buffer, loc, cache_k_nope, cache_k_rope)
+        return cache_k_nope, cache_k_rope
+
     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
         kv_cache_cpu = []
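
The new pair of functions gathers MLA KV-cache entries at the slots listed in loc and splits each row into its nope and rope halves, casting to the destination dtype on store. A minimal PyTorch sketch of the same computation, not part of the diff, assuming kv_buffer is laid out as (num_slots, 1, nope_dim + rope_dim) so that stride(0) matches the row stride passed above:

import torch

def get_mla_kv_buffer_reference(
    kv_buffer: torch.Tensor,     # (num_slots, 1, nope_dim + rope_dim)
    loc: torch.Tensor,           # (n_loc,) integer slot indices into kv_buffer
    cache_k_nope: torch.Tensor,  # (n_loc, 1, nope_dim)
    cache_k_rope: torch.Tensor,  # (n_loc, 1, rope_dim)
):
    nope_dim = cache_k_nope.shape[-1]
    # Gather the selected rows, then split the last dimension; copy_() performs
    # the implicit dtype conversion mentioned in the kernel's comment.
    gathered = kv_buffer[loc].view(loc.numel(), -1)
    cache_k_nope.view(loc.numel(), -1).copy_(gathered[:, :nope_dim])
    cache_k_rope.view(loc.numel(), -1).copy_(gathered[:, nope_dim:])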
sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py
@@ -3,8 +3,9 @@ import atexit
 import json
 import logging
 import threading
+from collections import OrderedDict
 from pathlib import Path
-from typing import Dict, List, Optional, OrderedDict, Tuple
+from typing import Dict, List, Optional, Tuple

 import orjson
 import requests
@@ -136,7 +137,7 @@ class GlobalMetadataState:
                 num_pages = data["num_pages"]
                 rank_meta = RankMetadata(num_pages)
                 rank_meta.free_pages = data["free_pages"]
-                rank_meta.key_to_index = dict(data["key_to_index"])
+                rank_meta.key_to_index = OrderedDict(data["key_to_index"])
                 self.ranks[rank_id] = rank_meta
             logging.info(
                 f"Successfully loaded metadata for {len(self.ranks)} ranks."
sglang/srt/model_executor/forward_batch_info.py
@@ -39,6 +39,7 @@ import triton
 import triton.language as tl

 from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_dp_rank,
@@ -250,6 +251,8 @@ class ForwardBatch:
     # For MLA chunked prefix cache used in chunked prefill
     # Tell attention backend whether lse needs to be returned
     mha_return_lse: Optional[bool] = None
+    mha_one_shot_kv_indices: Optional[torch.Tensor] = None
+    mha_one_shot: Optional[bool] = None

     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
@@ -572,9 +575,15 @@
                     device=model_runner.device,
                 )
             else:
-                mrope_position_deltas = mm_input.mrope_position_delta.flatten().to(
-                    model_runner.device, non_blocking=True
-                )
+                if mm_input.mrope_position_delta.device.type != model_runner.device:
+                    # Move mrope_position_delta to the device on the first run,
+                    # avoiding repeated host-to-device transfers on later runs.
+                    mm_input.mrope_position_delta = (
+                        mm_input.mrope_position_delta.to(
+                            model_runner.device, non_blocking=True
+                        )
+                    )
+                mrope_position_deltas = mm_input.mrope_position_delta.flatten()
             mrope_positions_list[batch_idx] = (
                 (mrope_position_deltas + self.seq_lens[batch_idx] - 1)
                 .unsqueeze(0)
@@ -863,6 +872,10 @@
             self.token_to_kv_pool, MLATokenToKVPool
         ), "Currently chunked prefix cache can only be used by Deepseek models"

+        if not any(self.extend_prefix_lens_cpu):
+            self.num_prefix_chunks = 0
+            return
+
         if self.prefix_chunk_len is not None:
             # Chunked kv cache info already prepared by prior modules
             return
@@ -917,6 +930,34 @@
     def can_run_tbo(self):
         return self.tbo_split_seq_index is not None

+    def fetch_mha_one_shot_kv_indices(self):
+        if self.mha_one_shot_kv_indices is not None:
+            return self.mha_one_shot_kv_indices
+        batch_size = self.batch_size
+        paged_kernel_lens_sum = sum(self.seq_lens_cpu)
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr = torch.zeros(
+            batch_size + 1,
+            dtype=torch.int32,
+            device=self.req_pool_indices.device,
+        )
+        kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        create_flashinfer_kv_indices_triton[(self.batch_size,)](
+            self.req_to_token_pool.req_to_token,
+            self.req_pool_indices,
+            self.seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token_pool.req_to_token.shape[1],
+        )
+        self.mha_one_shot_kv_indices = kv_indices
+        return kv_indices
+

 def enable_num_token_non_padded(server_args):
     return get_moe_expert_parallel_world_size() > 1
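
fetch_mha_one_shot_kv_indices caches a flat tensor of KV-cache slot indices covering every token of every request in the batch. Illustrative only, not from the diff: a plain-PyTorch equivalent of what create_flashinfer_kv_indices_triton writes into kv_indices here, assuming req_to_token maps (request pool slot, token position) to KV-cache slots as in the call above:

import torch

def build_one_shot_kv_indices(
    req_to_token: torch.Tensor,      # (max_requests, max_context_len)
    req_pool_indices: torch.Tensor,  # (batch_size,)
    seq_lens: torch.Tensor,          # (batch_size,)
) -> torch.Tensor:
    # Concatenate each request's first seq_len slot indices in batch order;
    # the Triton kernel produces the same flat layout, delimited by kv_indptr.
    parts = [
        req_to_token[int(req_idx), : int(seq_len)]
        for req_idx, seq_len in zip(req_pool_indices.tolist(), seq_lens.tolist())
    ]
    return torch.cat(parts).to(torch.int32)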
sglang/srt/model_executor/model_runner.py
@@ -131,16 +131,10 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cpu_ids_by_node,
     init_custom_process_group,
-    is_fa3_default_architecture,
-    is_flashinfer_available,
     is_hip,
-    is_hopper_with_cuda_12_3,
-    is_no_spec_infer_or_topk_one,
     is_npu,
-    is_sm100_supported,
     log_info_on_rank0,
     monkey_patch_p2p_access_check,
-    monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
     xpu_has_xmx_support,
@@ -503,121 +497,6 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            server_args.attention_backend == "intel_amx"
-            and server_args.device == "cpu"
-            and not _is_cpu_amx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel AMX, will fallback to torch_native backend."
-            )
-            server_args.attention_backend = "torch_native"
-
-        if (
-            server_args.attention_backend == "intel_xpu"
-            and server_args.device == "xpu"
-            and not _is_xpu_xmx_available
-        ):
-            logger.info(
-                "The current platform does not support Intel XMX, will fallback to triton backend."
-            )
-            server_args.attention_backend = "triton"
-
-        if server_args.prefill_attention_backend is not None and (
-            server_args.prefill_attention_backend
-            == server_args.decode_attention_backend
-        ):  # override the default attention backend
-            server_args.attention_backend = server_args.prefill_attention_backend
-
-        if (
-            getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
-            is not None
-        ):
-            if server_args.attention_backend is None:
-                server_args.attention_backend = "dual_chunk_flash_attn"
-                logger.info("Dual chunk attention is turned on by default.")
-            elif server_args.attention_backend != "dual_chunk_flash_attn":
-                raise ValueError(
-                    "Dual chunk attention is enabled, but attention backend is set to "
-                    f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
-                )
-
-        if server_args.attention_backend is None:
-            """
-            Auto select the fastest attention backend.
-
-            1. Models with MHA Architecture (e.g: Llama, QWen)
-                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
-                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
-            2. Models with MLA Architecture and using FA3
-                2.1 We will use FA3 backend on hopper.
-                2.2 We will use Flashinfer backend on blackwell.
-                2.3 Otherwise, we will use triton backend.
-            """
-
-            if not self.use_mla_backend:
-                # MHA architecture
-                if (
-                    is_hopper_with_cuda_12_3()
-                    and is_no_spec_infer_or_topk_one(server_args)
-                    and is_fa3_default_architecture(self.model_config.hf_config)
-                ):
-                    server_args.attention_backend = "fa3"
-                elif _is_hip:
-                    server_args.attention_backend = "aiter"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-            else:
-                # MLA architecture
-                if is_hopper_with_cuda_12_3():
-                    server_args.attention_backend = "fa3"
-                elif is_sm100_supported():
-                    server_args.attention_backend = "flashinfer"
-                elif _is_hip:
-                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
-                    # TODO current aiter only support head number 16 or 128 head number
-                    if head_num == 128 or head_num == 16:
-                        server_args.attention_backend = "aiter"
-                    else:
-                        server_args.attention_backend = "triton"
-                elif _is_npu:
-                    server_args.attention_backend = "ascend"
-                else:
-                    server_args.attention_backend = "triton"
-            log_info_on_rank0(
-                logger,
-                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
-            )
-        elif self.use_mla_backend:
-            if server_args.device != "cpu":
-                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
-                    logger.info(
-                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
-                    )
-                else:
-                    raise ValueError(
-                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
-                    )
-            else:
-                if server_args.attention_backend != "intel_amx":
-                    raise ValueError(
-                        "MLA optimization not supported on CPU except for intel_amx backend."
-                    )
-
-        if (
-            server_args.attention_backend == "fa3"
-            and server_args.kv_cache_dtype == "fp8_e5m2"
-        ):
-            logger.warning(
-                "FlashAttention3 only supports fp8_e4m3 if using FP8; "
-                "Setting attention backend to triton."
-            )
-            server_args.attention_backend = "triton"
-
         if server_args.enable_double_sparsity:
             logger.info(
                 "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
@@ -643,37 +522,12 @@ class ModelRunner:
         if not server_args.disable_chunked_prefix_cache:
             log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

-        if server_args.attention_backend == "aiter":
-            if self.model_config.context_len > 8192:
-                self.mem_fraction_static *= 0.85
-
-        if (
-            server_args.enable_hierarchical_cache
-            and server_args.hicache_io_backend == "kernel"
-        ):
-            # fix for the compatibility issue with FlashAttention3 decoding and HiCache kernel backend
-            if server_args.decode_attention_backend is None:
-                if not self.use_mla_backend:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_flashinfer_available() else "triton"
-                    )
-                else:
-                    server_args.decode_attention_backend = (
-                        "flashinfer" if is_sm100_supported() else "triton"
-                    )
-            elif server_args.decode_attention_backend == "fa3":
-                server_args.hicache_io_backend = "direct"
-                logger.warning(
-                    "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
-                )
-

         if self.model_config.hf_config.model_type == "qwen3_vl_moe":
             if (
                 quantization_config := getattr(
                     self.model_config.hf_config, "quantization_config", None
                 )
-            ) is not None:
+            ) is not None and "weight_block_size" in quantization_config:
                 weight_block_size_n = quantization_config["weight_block_size"][0]

                 if self.tp_size % self.moe_ep_size != 0:
@@ -858,8 +712,6 @@ class ModelRunner:
         self.model_config = adjust_config_with_unaligned_cpu_tp(
             self.model_config, self.load_config, self.tp_size
         )
-        if self.server_args.load_format == "gguf":
-            monkey_patch_vllm_gguf_config()

         if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
             if self.tp_rank == 0:
sglang/srt/model_executor/piecewise_cuda_graph_runner.py
@@ -32,7 +32,6 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
-from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.dp_attention import (
     DpPaddingMode,
     get_attention_tp_rank,
@@ -250,6 +249,9 @@ class PiecewiseCudaGraphRunner:
             lora_ids=None,
         )

+        # Attention backend
+        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
+
         with set_forward_context(forward_batch, self.attention_layers):
             _ = self.model_runner.model.forward(
                 forward_batch.input_ids,
@@ -262,9 +264,14 @@

     def can_run(self, forward_batch: ForwardBatch):
         num_tokens = len(forward_batch.input_ids)
-        # TODO(yuwei): support return logprob
+        # TODO(yuwei): support return input_ids' logprob
         if forward_batch.return_logprob:
-            return False
+            for start_len, seq_len in zip(
+                forward_batch.extend_logprob_start_lens_cpu,
+                forward_batch.extend_seq_lens_cpu,
+            ):
+                if start_len is not None and start_len < seq_len:
+                    return False
         if num_tokens <= self.max_num_tokens:
             return True
         return False
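
The relaxed can_run check above now admits batches with return_logprob as long as no prompt-token logprobs are requested, i.e. every request's logprob start is at or past the end of its extend range. A standalone sketch of that predicate (a hypothetical helper; only the field names mirror the code above):

from typing import Optional, Sequence

def only_sampled_token_logprobs(
    extend_logprob_start_lens: Sequence[Optional[int]],
    extend_seq_lens: Sequence[int],
) -> bool:
    # True when no request needs logprobs for tokens inside its extend range,
    # which is the condition under which the piecewise CUDA graph path stays usable.
    return all(
        start_len is None or start_len >= seq_len
        for start_len, seq_len in zip(extend_logprob_start_lens, extend_seq_lens)
    )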
@@ -273,10 +280,10 @@
         # Trigger CUDA graph capture for specific shapes.
         # Capture the large shapes first so that the smaller shapes
         # can reuse the memory pool allocated for the large shapes.
-        with freeze_gc(
-            self.model_runner.server_args.enable_cudagraph_gc
-        ), graph_capture() as graph_capture_context:
-            self.stream = graph_capture_context.stream
+        with freeze_gc(self.model_runner.server_args.enable_cudagraph_gc):
+            if self.model_runner.tp_group.ca_comm is not None:
+                old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+                self.model_runner.tp_group.ca_comm.disabled = True
             avail_mem = get_available_gpu_memory(
                 self.model_runner.device,
                 self.model_runner.gpu_id,
@@ -304,9 +311,10 @@

                 # Save gemlite cache after each capture
                 save_gemlite_cache()
+            if self.model_runner.tp_group.ca_comm is not None:
+                self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

     def capture_one_batch_size(self, num_tokens: int):
-        stream = self.stream
         bs = 1

         # Graph inputs
@@ -370,9 +378,6 @@
         if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

-        # # Attention backend
-        self.model_runner.attn_backend.init_forward_metadata(forward_batch)
-
         # Run and capture
         def run_once():
             # Clean intermediate result cache for DP attention
@@ -438,7 +443,7 @@
             out_cache_loc=out_cache_loc,
             seq_lens_sum=forward_batch.seq_lens_sum,
             encoder_lens=forward_batch.encoder_lens,
-            return_logprob=forward_batch.return_logprob,
+            return_logprob=False,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_prefix_lens=forward_batch.extend_prefix_lens,
             extend_start_loc=forward_batch.extend_start_loc,
@@ -474,6 +479,9 @@
         forward_batch: ForwardBatch,
         **kwargs,
     ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        if self.model_runner.tp_group.ca_comm is not None:
+            old_ca_disable = self.model_runner.tp_group.ca_comm.disabled
+            self.model_runner.tp_group.ca_comm.disabled = True
         static_forward_batch = self.replay_prepare(forward_batch, **kwargs)
         # Replay
         with set_forward_context(static_forward_batch, self.attention_layers):
@@ -499,6 +507,8 @@
             raise NotImplementedError(
                 "PPProxyTensors is not supported in PiecewiseCudaGraphRunner yet."
             )
+        if self.model_runner.tp_group.ca_comm is not None:
+            self.model_runner.tp_group.ca_comm.disabled = old_ca_disable

     def get_spec_info(self, num_tokens: int):
         spec_info = None
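
The same save/disable/restore handling of tp_group.ca_comm.disabled appears in both capture() and replay() above. Illustrative only, not part of the diff: the pattern written as a small context manager, assuming ca_comm is either None or exposes a mutable boolean disabled attribute:

from contextlib import contextmanager

@contextmanager
def ca_comm_disabled(ca_comm):
    # Temporarily disable custom allreduce, restoring the previous flag on exit.
    if ca_comm is None:
        yield
        return
    old_disabled = ca_comm.disabled
    ca_comm.disabled = True
    try:
        yield
    finally:
        ca_comm.disabled = old_disabled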