sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128) hide show
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Callable
5
5
 
6
6
  import torch
7
7
 
8
+ from sglang.srt.layers.dp_attention import DPPaddingMode
8
9
  from sglang.srt.model_executor.cuda_graph_runner import (
9
10
  CUDA_GRAPH_CAPTURE_FAILED_MSG,
10
11
  CudaGraphRunner,
@@ -109,13 +110,6 @@ class EAGLEDraftExtendCudaGraphRunner:
109
110
  )
110
111
 
111
112
  if self.require_gathered_buffer:
112
- self.gathered_buffer = torch.zeros(
113
- (
114
- self.max_num_token,
115
- self.model_runner.model_config.hidden_size,
116
- ),
117
- dtype=self.model_runner.dtype,
118
- )
119
113
  if self.require_mlp_tp_gather:
120
114
  self.global_num_tokens_gpu = torch.zeros(
121
115
  (self.dp_size,), dtype=torch.int32
@@ -123,12 +117,31 @@ class EAGLEDraftExtendCudaGraphRunner:
123
117
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
124
118
  (self.dp_size,), dtype=torch.int32
125
119
  )
120
+ self.gathered_buffer = torch.zeros(
121
+ (
122
+ self.max_num_token * self.dp_size,
123
+ self.model_runner.model_config.hidden_size,
124
+ ),
125
+ dtype=self.model_runner.dtype,
126
+ )
126
127
  else:
127
128
  assert self.require_attn_tp_gather
128
129
  self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
129
130
  self.global_num_tokens_for_logprob_gpu = torch.zeros(
130
131
  (1,), dtype=torch.int32
131
132
  )
133
+ self.gathered_buffer = torch.zeros(
134
+ (
135
+ self.max_num_token,
136
+ self.model_runner.model_config.hidden_size,
137
+ ),
138
+ dtype=self.model_runner.dtype,
139
+ )
140
+ else:
141
+ self.global_num_tokens_gpu = None
142
+ self.global_num_tokens_for_logprob_gpu = None
143
+ self.gathered_buffer = None
144
+
132
145
  # Capture
133
146
  try:
134
147
  with model_capture_mode():
@@ -141,9 +154,9 @@ class EAGLEDraftExtendCudaGraphRunner:
141
154
  def can_run(self, forward_batch: ForwardBatch):
142
155
  if self.require_mlp_tp_gather:
143
156
  cuda_graph_bs = (
144
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
157
+ max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
145
158
  if self.model_runner.spec_algorithm.is_eagle()
146
- else sum(forward_batch.global_num_tokens_cpu)
159
+ else max(forward_batch.global_num_tokens_cpu)
147
160
  )
148
161
  else:
149
162
  cuda_graph_bs = forward_batch.seq_lens.numel()
@@ -180,27 +193,19 @@ class EAGLEDraftExtendCudaGraphRunner:
180
193
  if self.require_mlp_tp_gather:
181
194
  self.global_num_tokens_gpu.copy_(
182
195
  torch.tensor(
183
- [
184
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
185
- for i in range(self.dp_size)
186
- ],
196
+ [num_tokens] * self.dp_size,
187
197
  dtype=torch.int32,
188
198
  device=self.input_ids.device,
189
199
  )
190
200
  )
191
201
  self.global_num_tokens_for_logprob_gpu.copy_(
192
202
  torch.tensor(
193
- [
194
- num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
195
- for i in range(self.dp_size)
196
- ],
203
+ [bs] * self.dp_size,
197
204
  dtype=torch.int32,
198
205
  device=self.input_ids.device,
199
206
  )
200
207
  )
201
- global_num_tokens = self.global_num_tokens_gpu
202
- gathered_buffer = self.gathered_buffer[:num_tokens]
203
- global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
208
+ gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
204
209
  elif self.require_attn_tp_gather:
205
210
  self.global_num_tokens_gpu.copy_(
206
211
  torch.tensor(
@@ -211,18 +216,14 @@ class EAGLEDraftExtendCudaGraphRunner:
211
216
  )
212
217
  self.global_num_tokens_for_logprob_gpu.copy_(
213
218
  torch.tensor(
214
- [num_tokens],
219
+ [bs],
215
220
  dtype=torch.int32,
216
221
  device=self.input_ids.device,
217
222
  )
218
223
  )
219
- global_num_tokens = self.global_num_tokens_gpu
220
224
  gathered_buffer = self.gathered_buffer[:num_tokens]
221
- global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
222
225
  else:
223
- global_num_tokens = None
224
226
  gathered_buffer = None
225
- global_num_tokens_for_logprob = None
226
227
 
227
228
  spec_info = EagleDraftInput(
228
229
  hidden_states=hidden_states,
@@ -243,8 +244,9 @@ class EAGLEDraftExtendCudaGraphRunner:
243
244
  seq_lens_sum=seq_lens.sum().item(),
244
245
  return_logprob=False,
245
246
  positions=positions,
246
- global_num_tokens_gpu=global_num_tokens,
247
- global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
247
+ global_num_tokens_gpu=self.global_num_tokens_gpu,
248
+ global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
249
+ dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
248
250
  gathered_buffer=gathered_buffer,
249
251
  spec_algorithm=self.model_runner.spec_algorithm,
250
252
  spec_info=spec_info,
@@ -306,12 +308,13 @@ class EAGLEDraftExtendCudaGraphRunner:
306
308
  raw_bs = forward_batch.batch_size
307
309
  num_tokens = forward_batch.input_ids.shape[0]
308
310
  if self.require_mlp_tp_gather:
309
- total_batch_size = (
310
- sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
311
+ max_num_tokens = max(forward_batch.global_num_tokens_cpu)
312
+ max_batch_size = (
313
+ max_num_tokens // self.num_tokens_per_bs
311
314
  if self.model_runner.spec_algorithm.is_eagle()
312
- else sum(forward_batch.global_num_tokens_cpu)
315
+ else max_num_tokens
313
316
  )
314
- index = bisect.bisect_left(self.capture_bs, total_batch_size)
317
+ index = bisect.bisect_left(self.capture_bs, max_batch_size)
315
318
  else:
316
319
  index = bisect.bisect_left(self.capture_bs, raw_bs)
317
320
 
@@ -334,12 +337,10 @@ class EAGLEDraftExtendCudaGraphRunner:
334
337
  self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
335
338
  self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
336
339
 
340
+ # TODO(ch-wan): support num_token_non_padded
337
341
  if self.require_gathered_buffer:
338
- self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
339
- self.global_num_tokens_for_logprob_gpu.copy_(
340
- forward_batch.global_num_tokens_for_logprob_gpu
341
- )
342
- forward_batch.gathered_buffer = self.gathered_buffer
342
+ self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
343
+ self.global_num_tokens_for_logprob_gpu.fill_(bs)
343
344
 
344
345
  if forward_batch.seq_lens_cpu is not None:
345
346
  if bs != raw_bs:
@@ -1,5 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import copy
3
4
  import logging
4
5
  import os
5
6
  import time
@@ -70,9 +71,20 @@ class EagleDraftInput:
70
71
  kv_indptr: torch.Tensor = None
71
72
  kv_indices: torch.Tensor = None
72
73
 
74
+ # Shape info for padding
75
+ num_tokens_per_batch: int = -1
76
+ num_tokens_for_logprob_per_batch: int = -1
77
+
78
+ # Inputs for draft extend
79
+ # shape: (b,)
80
+ seq_lens_for_draft_extend: torch.Tensor = None
81
+ req_pool_indices_for_draft_extend: torch.Tensor = None
82
+
73
83
  def prepare_for_extend(self, batch: ScheduleBatch):
84
+
74
85
  if batch.forward_mode.is_idle():
75
86
  return
87
+
76
88
  # Prefill only generate 1 token.
77
89
  assert len(self.verified_id) == len(batch.seq_lens)
78
90
 
@@ -94,7 +106,7 @@ class EagleDraftInput:
94
106
  capture_hidden_mode: CaptureHiddenMode,
95
107
  ):
96
108
  return cls(
97
- verified_id=None,
109
+ verified_id=torch.empty((0,), device=device, dtype=torch.int32),
98
110
  hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
99
111
  topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
100
112
  topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
@@ -108,7 +120,10 @@ class EagleDraftInput:
108
120
  batch: ScheduleBatch,
109
121
  speculative_num_steps: int,
110
122
  ):
111
- batch.forward_mode = ForwardMode.DRAFT_EXTEND
123
+
124
+ if batch.forward_mode.is_idle():
125
+ return
126
+
112
127
  batch.input_ids = self.verified_id
113
128
  batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
114
129
  batch.extend_num_tokens = sum(batch.extend_lens)
@@ -315,7 +330,7 @@ class EagleVerifyInput:
315
330
  def verify(
316
331
  self,
317
332
  batch: ScheduleBatch,
318
- logits_output: torch.Tensor,
333
+ logits_output: LogitsProcessorOutput,
319
334
  token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
320
335
  page_size: int,
321
336
  vocab_mask: Optional[torch.Tensor] = None, # For grammar
@@ -362,6 +377,11 @@ class EagleVerifyInput:
362
377
  )
363
378
  accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
364
379
 
380
+ if bs != len(sampling_info):
381
+ sampling_info = copy.deepcopy(sampling_info)
382
+ # NOTE: retrive_index are the indices of the requests that are kept.
383
+ sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
384
+
365
385
  # Apply the custom logit processors if registered in the sampling info.
366
386
  if sampling_info.has_custom_logit_processor:
367
387
  apply_custom_logit_processor(
@@ -593,13 +613,14 @@ class EagleVerifyInput:
593
613
  batch.out_cache_loc = tgt_cache_loc
594
614
  batch.seq_lens.add_(accept_length + 1)
595
615
 
596
- draft_input = EagleDraftInput()
597
- draft_input.hidden_states = batch.spec_info.hidden_states[accept_index]
598
- draft_input.verified_id = verified_id
599
- draft_input.accept_length = accept_length
600
- draft_input.accept_length_cpu = accept_length.tolist()
601
- draft_input.seq_lens_for_draft_extend = batch.seq_lens
602
- draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
616
+ draft_input = EagleDraftInput(
617
+ hidden_states=batch.spec_info.hidden_states[accept_index],
618
+ verified_id=verified_id,
619
+ accept_length=accept_length,
620
+ accept_length_cpu=accept_length.tolist(),
621
+ seq_lens_for_draft_extend=batch.seq_lens,
622
+ req_pool_indices_for_draft_extend=batch.req_pool_indices,
623
+ )
603
624
 
604
625
  return EagleVerifyOutput(
605
626
  draft_input=draft_input,
@@ -622,7 +643,6 @@ class EagleVerifyInput:
622
643
  batch.seq_lens.add_(accept_length + 1)
623
644
 
624
645
  accept_length_cpu = accept_length.tolist()
625
- draft_input = EagleDraftInput()
626
646
  if len(unfinished_accept_index) > 0:
627
647
  unfinished_accept_index = torch.cat(unfinished_accept_index)
628
648
  unfinished_index_device = torch.tensor(
@@ -653,18 +673,26 @@ class EagleVerifyInput:
653
673
  next_power_of_2(self.draft_token_num),
654
674
  )
655
675
 
656
- draft_input.hidden_states = batch.spec_info.hidden_states[
657
- unfinished_accept_index
658
- ]
659
- draft_input.verified_id = predict[unfinished_accept_index]
660
- draft_input.accept_length_cpu = draft_input_accept_length_cpu
661
- draft_input.accept_length = accept_length[unfinished_index_device]
662
- draft_input.seq_lens_for_draft_extend = batch.seq_lens[
663
- unfinished_index_device
664
- ]
665
- draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
666
- unfinished_index_device
667
- ]
676
+ draft_input = EagleDraftInput(
677
+ hidden_states=batch.spec_info.hidden_states[
678
+ unfinished_accept_index
679
+ ],
680
+ verified_id=predict[unfinished_accept_index],
681
+ accept_length_cpu=draft_input_accept_length_cpu,
682
+ accept_length=accept_length[unfinished_index_device],
683
+ seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
684
+ req_pool_indices_for_draft_extend=batch.req_pool_indices[
685
+ unfinished_index_device
686
+ ],
687
+ )
688
+ else:
689
+ draft_input = EagleDraftInput.create_idle_input(
690
+ device=batch.device,
691
+ hidden_size=batch.model_config.hidden_size,
692
+ dtype=batch.model_config.dtype,
693
+ topk=self.topk,
694
+ capture_hidden_mode=CaptureHiddenMode.LAST,
695
+ )
668
696
 
669
697
  return EagleVerifyOutput(
670
698
  draft_input=draft_input,
@@ -297,7 +297,7 @@ class EAGLEWorker(TpModelWorker):
297
297
 
298
298
  def forward_batch_speculative_generation(
299
299
  self, batch: ScheduleBatch
300
- ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
300
+ ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, int, bool]:
301
301
  """Run speculative decoding forward.
302
302
 
303
303
  NOTE: Many states of batch is modified as you go through. It is not guaranteed that
@@ -325,11 +325,16 @@ class EAGLEWorker(TpModelWorker):
325
325
  self.verify(batch, spec_info)
326
326
  )
327
327
 
328
- if self.check_forward_draft_extend_after_decode(batch):
329
- with self.draft_tp_context(self.draft_model_runner.tp_group):
330
- self.forward_draft_extend_after_decode(
331
- batch,
332
- )
328
+ with self.draft_tp_context(self.draft_model_runner.tp_group):
329
+ # NOTE: We should use `check_forward_draft_extend_after_decode`
330
+ # when DP attention is enabled, but it is slow. Skip it for now.
331
+ if (
332
+ self.server_args.enable_dp_attention
333
+ or batch.spec_info.verified_id.shape[0] > 0
334
+ ):
335
+ # decode is not finished
336
+ self.forward_draft_extend_after_decode(batch)
337
+
333
338
  return (
334
339
  logits_output,
335
340
  verify_output.verified_id,
@@ -339,10 +344,7 @@ class EAGLEWorker(TpModelWorker):
339
344
  )
340
345
 
341
346
  def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
342
- local_need_forward = (
343
- batch.spec_info.verified_id is not None
344
- and batch.spec_info.verified_id.shape[0] > 0
345
- )
347
+ local_need_forward = batch.spec_info.verified_id.shape[0] > 0
346
348
  if not self.server_args.enable_dp_attention:
347
349
  return local_need_forward
348
350
 
@@ -361,7 +363,7 @@ class EAGLEWorker(TpModelWorker):
361
363
 
362
364
  def forward_target_extend(
363
365
  self, batch: ScheduleBatch
364
- ) -> Tuple[LogitsProcessorOutput, List[int], int]:
366
+ ) -> Tuple[LogitsProcessorOutput, torch.Tensor, int, Optional[torch.Tensor]]:
365
367
  """Run the target extend.
366
368
 
367
369
  Args:
@@ -376,7 +378,6 @@ class EAGLEWorker(TpModelWorker):
376
378
  # We need the full hidden states to prefill the KV cache of the draft model.
377
379
  model_worker_batch = batch.get_model_worker_batch()
378
380
  model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
379
- model_worker_batch.spec_num_draft_tokens = 1
380
381
  logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
381
382
  model_worker_batch
382
383
  )
@@ -508,13 +509,15 @@ class EAGLEWorker(TpModelWorker):
508
509
  self._draft_preprocess_decode(batch)
509
510
 
510
511
  spec_info = batch.spec_info
512
+ assert isinstance(spec_info, EagleDraftInput)
511
513
 
512
514
  spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
515
+ spec_info.num_tokens_per_batch = self.topk
516
+ spec_info.num_tokens_for_logprob_per_batch = self.topk
513
517
  batch.return_hidden_states = False
514
518
 
515
519
  # Get forward batch
516
520
  model_worker_batch = batch.get_model_worker_batch()
517
- model_worker_batch.spec_num_draft_tokens = self.topk
518
521
  assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
519
522
  forward_batch = ForwardBatch.init_new(
520
523
  model_worker_batch, self.draft_model_runner
@@ -527,6 +530,7 @@ class EAGLEWorker(TpModelWorker):
527
530
  forward_batch
528
531
  )
529
532
  else:
533
+ forward_batch.can_run_dp_cuda_graph = False
530
534
  if not forward_batch.forward_mode.is_idle():
531
535
  # Initialize attention backend
532
536
  self.draft_attn_backend.init_forward_metadata(forward_batch)
@@ -578,6 +582,7 @@ class EAGLEWorker(TpModelWorker):
578
582
  def draft_forward(self, forward_batch: ForwardBatch):
579
583
  # Parse args
580
584
  spec_info = forward_batch.spec_info
585
+ assert isinstance(spec_info, EagleDraftInput)
581
586
  out_cache_loc = forward_batch.out_cache_loc
582
587
  topk_p, topk_index, hidden_states = (
583
588
  spec_info.topk_p,
@@ -621,8 +626,8 @@ class EAGLEWorker(TpModelWorker):
621
626
  spec_info.hidden_states = hidden_states
622
627
 
623
628
  # Run forward
624
- logits_output = self.draft_model_runner.model.forward(
625
- forward_batch.input_ids, forward_batch.positions, forward_batch
629
+ logits_output, _ = self.draft_model_runner.forward(
630
+ forward_batch, skip_attn_backend_init=True
626
631
  )
627
632
  self._detect_nan_if_needed(logits_output)
628
633
  probs = torch.softmax(logits_output.next_token_logits, dim=-1)
@@ -642,10 +647,10 @@ class EAGLEWorker(TpModelWorker):
642
647
  else ForwardMode.IDLE
643
648
  )
644
649
  batch.spec_info = spec_info
650
+
645
651
  model_worker_batch = batch.get_model_worker_batch(
646
652
  seq_lens_cpu_cache=spec_info.seq_lens_cpu
647
653
  )
648
- model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
649
654
  assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
650
655
 
651
656
  if batch.has_grammar:
@@ -782,8 +787,8 @@ class EAGLEWorker(TpModelWorker):
782
787
  self,
783
788
  batch: ScheduleBatch,
784
789
  hidden_states: torch.Tensor,
785
- next_token_ids: List[int],
786
- seq_lens_cpu: torch.Tensor,
790
+ next_token_ids: torch.Tensor,
791
+ seq_lens_cpu: Optional[torch.Tensor],
787
792
  ):
788
793
  """Run draft model extend. This API modifies the states of the batch.
789
794
 
@@ -795,6 +800,8 @@ class EAGLEWorker(TpModelWorker):
795
800
  batch.spec_info = EagleDraftInput(
796
801
  hidden_states=hidden_states,
797
802
  verified_id=next_token_ids,
803
+ num_tokens_per_batch=1,
804
+ num_tokens_for_logprob_per_batch=1,
798
805
  )
799
806
  batch.return_hidden_states = False
800
807
  batch.spec_info.prepare_for_extend(batch)
@@ -802,7 +809,6 @@ class EAGLEWorker(TpModelWorker):
802
809
  model_worker_batch = batch.get_model_worker_batch(
803
810
  seq_lens_cpu_cache=seq_lens_cpu
804
811
  )
805
- model_worker_batch.spec_num_draft_tokens = 1
806
812
  forward_batch = ForwardBatch.init_new(
807
813
  model_worker_batch, self.draft_model_runner
808
814
  )
@@ -814,37 +820,45 @@ class EAGLEWorker(TpModelWorker):
814
820
  self.capture_for_decode(logits_output, forward_batch.spec_info)
815
821
 
816
822
  def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
823
+ assert isinstance(batch.spec_info, EagleDraftInput)
817
824
  # Backup fields that will be modified in-place
818
825
  seq_lens_backup = batch.seq_lens.clone()
819
826
  req_pool_indices_backup = batch.req_pool_indices
820
827
  accept_length_backup = batch.spec_info.accept_length
821
828
  return_logprob_backup = batch.return_logprob
829
+
822
830
  input_is_idle = batch.forward_mode.is_idle()
823
- if not input_is_idle:
824
- # Prepare metadata
825
- if batch.spec_info.verified_id is not None:
826
- batch.spec_info.prepare_extend_after_decode(
827
- batch,
828
- self.speculative_num_steps,
829
- )
830
- else:
831
- batch = batch.copy()
832
- batch.prepare_for_idle()
833
- hidden_size = (
834
- self.model_config.hidden_size * 3
835
- if self.speculative_algorithm.is_eagle3()
836
- else self.model_config.hidden_size
837
- )
838
- batch.spec_info = EagleDraftInput.create_idle_input(
839
- device=self.device,
840
- hidden_size=hidden_size,
841
- dtype=self.model_config.dtype,
842
- topk=self.topk,
843
- capture_hidden_mode=CaptureHiddenMode.LAST,
844
- )
831
+
832
+ if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
833
+ batch = batch.copy()
834
+ batch.prepare_for_idle()
835
+ hidden_size = (
836
+ self.model_config.hidden_size * 3
837
+ if self.speculative_algorithm.is_eagle3()
838
+ else self.model_config.hidden_size
839
+ )
840
+ batch.spec_info = EagleDraftInput.create_idle_input(
841
+ device=self.device,
842
+ hidden_size=hidden_size,
843
+ dtype=self.model_config.dtype,
844
+ topk=self.topk,
845
+ capture_hidden_mode=CaptureHiddenMode.LAST,
846
+ )
847
+
848
+ batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
849
+ batch.spec_info.num_tokens_for_logprob_per_batch = 1
850
+ batch.spec_info.prepare_extend_after_decode(
851
+ batch,
852
+ self.speculative_num_steps,
853
+ )
854
+ batch.forward_mode = (
855
+ ForwardMode.DRAFT_EXTEND
856
+ if not batch.forward_mode.is_idle()
857
+ else ForwardMode.IDLE
858
+ )
859
+
845
860
  batch.return_hidden_states = False
846
861
  model_worker_batch = batch.get_model_worker_batch()
847
- model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
848
862
  assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
849
863
  forward_batch = ForwardBatch.init_new(
850
864
  model_worker_batch, self.draft_model_runner
@@ -869,12 +883,13 @@ class EAGLEWorker(TpModelWorker):
869
883
  )
870
884
  forward_batch.spec_info.hidden_states = logits_output.hidden_states
871
885
  else:
886
+ forward_batch.can_run_dp_cuda_graph = False
872
887
  if not forward_batch.forward_mode.is_idle():
873
888
  self.draft_model_runner.attn_backend.init_forward_metadata(
874
889
  forward_batch
875
890
  )
876
- logits_output = self.draft_model_runner.model.forward(
877
- forward_batch.input_ids, forward_batch.positions, forward_batch
891
+ logits_output, _ = self.draft_model_runner.forward(
892
+ forward_batch, skip_attn_backend_init=True
878
893
  )
879
894
  self.capture_for_decode(logits_output, forward_batch.spec_info)
880
895
 
@@ -341,15 +341,18 @@ class TboDPAttentionPreparer:
341
341
 
342
342
  @staticmethod
343
343
  def _compute_global_forward_mode(forward_modes):
344
- converted_forward_modes = [
345
- ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
346
- for x in forward_modes
344
+ forward_modes_excluding_idle = [
345
+ x for x in forward_modes if x != ForwardMode.IDLE.value
347
346
  ]
347
+
348
+ if not forward_modes_excluding_idle:
349
+ return ForwardMode.IDLE, False
350
+
348
351
  forward_mode_agree = TboDPAttentionPreparer._is_all_same(
349
- converted_forward_modes
352
+ forward_modes_excluding_idle
350
353
  )
351
354
  global_forward_mode = (
352
- ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
355
+ ForwardMode(forward_modes_excluding_idle[0]) if forward_mode_agree else None
353
356
  )
354
357
  return global_forward_mode, forward_mode_agree
355
358
 
@@ -542,6 +545,7 @@ class TboForwardBatchPreparer:
542
545
  tbo_children=None,
543
546
  global_num_tokens_gpu=None,
544
547
  global_num_tokens_cpu=None,
548
+ dp_padding_mode=None,
545
549
  gathered_buffer=gathered_buffer,
546
550
  global_num_tokens_for_logprob_gpu=None,
547
551
  global_num_tokens_for_logprob_cpu=None,