sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mimo_mtp.py CHANGED
@@ -7,33 +7,17 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    split_tensor_along_last_dim,
-    tensor_model_parallel_all_gather,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.models.mimo import MiMoForCausalLM
-from sglang.srt.models.qwen2 import (
-    Qwen2Attention,
-    Qwen2DecoderLayer,
-    Qwen2MLP,
-    Qwen2Model,
-)
-from sglang.srt.utils import add_prefix
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer


 class MiMoMultiTokenPredictorLayer(nn.Module):
sglang/srt/reasoning_parser.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Tuple
+from typing import Dict, Optional, Tuple, Type


 class StreamingParseResult:
@@ -32,17 +32,26 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
-        if self.think_end_token not in text:
+        in_reasoning = self._in_reasoning or text.startswith(self.think_start_token)
+
+        if not in_reasoning:
+            return StreamingParseResult(normal_text=text)
+
+        # The text is considered to be in a reasoning block.
+        processed_text = text.replace(self.think_start_token, "").strip()
+
+        if self.think_end_token not in processed_text:
             # Assume reasoning was truncated before `</think>` token
-            return StreamingParseResult(reasoning_text=text)
+            return StreamingParseResult(reasoning_text=processed_text)

         # Extract reasoning content
-        splits = text.split(self.think_end_token, maxsplit=1)
+        splits = processed_text.split(self.think_end_token, maxsplit=1)
         reasoning_text = splits[0]
-        text = splits[1].strip()
+        normal_text = splits[1].strip()

-        return StreamingParseResult(normal_text=text, reasoning_text=reasoning_text)
+        return StreamingParseResult(
+            normal_text=normal_text, reasoning_text=reasoning_text
+        )

     def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
         """
@@ -61,6 +70,7 @@ class BaseReasoningFormatDetector:
         if not self.stripped_think_start and self.think_start_token in current_text:
             current_text = current_text.replace(self.think_start_token, "")
             self.stripped_think_start = True
+            self._in_reasoning = True

         # Handle end of reasoning block
         if self._in_reasoning and self.think_end_token in current_text:
@@ -131,11 +141,11 @@ class Qwen3Detector(BaseReasoningFormatDetector):
     """

     def __init__(self, stream_reasoning: bool = True):
-        # Qwen3 is assumed to be reasoning until `</think>` token
+        # Qwen3 won't be in reasoning mode when user passes `enable_thinking=False`
         super().__init__(
             "<think>",
             "</think>",
-            force_reasoning=True,
+            force_reasoning=False,
             stream_reasoning=stream_reasoning,
         )

@@ -151,12 +161,12 @@ class ReasoningParser:
         If True, streams reasoning content as it arrives.
     """

-    DetectorMap: Dict[str, BaseReasoningFormatDetector] = {
+    DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = {
         "deepseek-r1": DeepSeekR1Detector,
         "qwen3": Qwen3Detector,
     }

-    def __init__(self, model_type: str = None, stream_reasoning: bool = True):
+    def __init__(self, model_type: Optional[str] = None, stream_reasoning: bool = True):
        if not model_type:
            raise ValueError("Model type must be specified")

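The practical effect of flipping force_reasoning: with the qwen3 detector, output that never opens a `<think>` block now comes back as plain normal text instead of being treated as truncated reasoning. A small usage sketch (assuming, as in the upstream module, that the one-shot entry point on a detector is detect_and_parse):

from sglang.srt.reasoning_parser import Qwen3Detector

det = Qwen3Detector(stream_reasoning=False)

# Reasoning block present: both parts are recovered.
res = det.detect_and_parse("<think>weigh the options</think>The answer is 42.")
print(res.reasoning_text)  # "weigh the options"
print(res.normal_text)     # "The answer is 42."

# No <think> block (e.g. the request set enable_thinking=False): since
# force_reasoning is now False, the whole string is returned as normal_text.
print(det.detect_and_parse("Just a direct reply.").normal_text)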
sglang/srt/server_args.py CHANGED
@@ -47,6 +47,7 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
@@ -152,6 +153,7 @@ class ServerArgs:
     ep_size: int = 1
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
+    enable_flashinfer_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -234,6 +236,10 @@ class ServerArgs:
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     pdlb_url: Optional[str] = None

+    # For model weight update
+    custom_weight_loader: Optional[List[str]] = None
+    weight_loader_disable_mmap: bool = False
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -241,7 +247,15 @@
             logger.warning(
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
-
+        if self.enable_flashinfer_moe:
+            assert (
+                self.quantization == "modelopt_fp4"
+            ), "modelopt_fp4 quantization is required for Flashinfer MOE"
+            os.environ["TRTLLM_ENABLE_PDL"] = "1"
+            self.disable_shared_experts_fusion = True
+            logger.warning(
+                f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
+            )
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
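For illustration only, the precondition this new branch enforces, written as a direct dataclass construction (the model path is a placeholder; in a real deployment these fields come from the CLI, and whether construction runs end to end depends on the environment):

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="/path/to/model",   # placeholder path
    enable_ep_moe=True,            # FlashInfer MoE composes with MoE-EP
    enable_flashinfer_moe=True,
    quantization="modelopt_fp4",   # anything else trips the assert above
)
# __post_init__ also exports TRTLLM_ENABLE_PDL=1 and forces
# disable_shared_experts_fusion=True, as shown in the hunk.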
@@ -384,7 +398,6 @@
             ), "Please enable dp attention when setting enable_dp_attention. "

         # DeepEP MoE
-        self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
             if self.deepep_mode == "auto":
                 assert (
@@ -394,9 +407,6 @@
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
             self.ep_size = self.tp_size
-            self.enable_sp_layernorm = (
-                self.dp_size < self.tp_size if self.enable_dp_attention else True
-            )
             logger.warning(
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
@@ -538,6 +548,9 @@
             "1" if self.disable_outlines_disk_cache else "0"
         )

+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
@@ -551,6 +564,7 @@
         # Model and port args
         parser.add_argument(
             "--model-path",
+            "--model",
             type=str,
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
@@ -620,6 +634,13 @@
             "layer before loading another to make the peak memory envelope "
             "smaller.",
         )
+        parser.add_argument(
+            "--model-loader-extra-config",
+            type=str,
+            help="Extra config for model loader. "
+            "This will be passed to the model loader corresponding to the chosen load_format.",
+            default=ServerArgs.model_loader_extra_config,
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -1160,6 +1181,11 @@
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-moe",
+            action="store_true",
+            help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1576,6 +1602,18 @@
             default=None,
             help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
         )
+        parser.add_argument(
+            "--custom-weight-loader",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+        )
+        parser.add_argument(
+            "--weight-loader-disable-mmap",
+            action="store_true",
+            help="Disable mmap while loading weight using safetensors.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
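Taken together, the new loader-related flags can be exercised through the normal argparse round trip; a sketch (the model path and extra-config payload are placeholders):

import argparse
from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model", "/path/to/model",                         # "--model" is the new alias of --model-path
    "--model-loader-extra-config", '{"key": "value"}',   # placeholder JSON for the chosen load_format
    "--custom-weight-loader", "my_package.weight_load_func",
    "--weight-loader-disable-mmap",
])
server_args = ServerArgs.from_cli_args(args)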
@@ -1663,6 +1701,9 @@
     # The ipc filename for rpc call between Engine and Scheduler
     rpc_ipc_name: str

+    # The ipc filename for Scheduler to send metrics
+    metrics_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
@@ -1682,6 +1723,7 @@
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1700,11 +1742,10 @@
             dist_init_host, dist_init_port = dist_init_addr
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
-                scheduler_input_port = (
-                    port_base + 3
-                )  # TokenizerManager to DataParallelController
+                # TokenizerManager to DataParallelController
+                scheduler_input_port = port_base + 4
             else:
-                scheduler_input_port = port_base + 3 + 1 + dp_rank
+                scheduler_input_port = port_base + 4 + 1 + dp_rank

             return PortArgs(
                 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
@@ -1712,6 +1753,7 @@
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
+                metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
             )


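Laid out explicitly, the DP-attention TCP layout implied by these two hunks (a purely illustrative helper, not part of the package):

def dp_attention_port_layout(dist_init_port: int, dp_size: int) -> dict:
    port_base = dist_init_port + 1
    return {
        "tokenizer_ipc": port_base,
        "detokenizer_ipc": port_base + 1,
        "rpc_ipc": port_base + 2,
        "metrics_ipc": port_base + 3,       # new in 0.4.8
        "controller_input": port_base + 4,  # TokenizerManager -> DataParallelController, was +3
        "scheduler_input": [port_base + 4 + 1 + r for r in range(dp_size)],
    }

print(dp_attention_port_layout(25000, dp_size=2))
# {'tokenizer_ipc': 25001, 'detokenizer_ipc': 25002, 'rpc_ipc': 25003,
#  'metrics_ipc': 25004, 'controller_input': 25005, 'scheduler_input': [25006, 25007]}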
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -38,6 +44,12 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +65,9 @@
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +92,32 @@
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
-            (self.max_num_token, self.model_runner.model_config.hidden_size),
+            (self.max_bs, self.model_runner.model_config.hidden_size),
             dtype=self.model_runner.dtype,
         )

+        if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+
         # Capture
         try:
             with model_capture_mode():
@@ -92,11 +128,24 @@
             )

     def can_run(self, forward_batch: ForwardBatch):
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.batch_size
+
         is_bs_supported = (
-            forward_batch.batch_size in self.graphs
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
+            else cuda_graph_bs <= self.max_bs
         )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported

     def capture(self):
@@ -116,8 +165,58 @@
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]

+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
-            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )

         # Forward batch
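The list copied into global_num_tokens_gpu above is a balanced split of the padded token count across DP ranks; pulled out here as a hypothetical helper for clarity:

def split_tokens_evenly(num_tokens: int, dp_size: int) -> list:
    # Each rank gets floor(num_tokens / dp_size); the first
    # (num_tokens % dp_size) ranks get one extra token.
    return [
        num_tokens // dp_size + (i < num_tokens % dp_size)
        for i in range(dp_size)
    ]

print(split_tokens_evenly(10, 4))        # [3, 3, 2, 2]
assert sum(split_tokens_evenly(10, 4)) == 10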
@@ -133,11 +232,14 @@
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )

         # Attention backend
@@ -147,6 +249,9 @@

         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,12 +289,19 @@
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()

         num_tokens = bs * self.num_tokens_per_bs

@@ -204,6 +316,13 @@
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -212,14 +331,16 @@
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph

         # Replay
         self.graphs[bs].replay()