sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package. The information shown here is provided for informational purposes only and reflects the changes between the package versions as published in their public registry.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,12 @@ from sglang.srt.model_executor.forward_batch_info import (
21
21
  ForwardMode,
22
22
  )
23
23
  from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
24
+ from sglang.srt.utils import (
25
+ require_attn_tp_gather,
26
+ require_gathered_buffer,
27
+ require_mlp_sync,
28
+ require_mlp_tp_gather,
29
+ )
24
30
 
25
31
  if TYPE_CHECKING:
26
32
  from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,6 +41,10 @@ class EAGLEDraftExtendCudaGraphRunner:
35
41
  self.output_buffers = {}
36
42
  self.enable_torch_compile = model_runner.server_args.enable_torch_compile
37
43
  self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
44
+ self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
45
+ self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
46
+ self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
47
+ self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
38
48
  self.tp_size = self.model_runner.tp_size
39
49
  self.dp_size = model_runner.server_args.dp_size
40
50
  self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +61,7 @@ class EAGLEDraftExtendCudaGraphRunner:
51
61
  self.max_num_token = self.max_bs * self.num_tokens_per_bs
52
62
 
53
63
  self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
54
- self.max_num_token
64
+ self.max_bs, self.max_num_token
55
65
  )
56
66
  self.seq_len_fill_value = (
57
67
  self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +100,27 @@ class EAGLEDraftExtendCudaGraphRunner:
90
100
  (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
91
101
  )
92
102
 
103
+ if self.require_gathered_buffer:
104
+ self.gathered_buffer = torch.zeros(
105
+ (
106
+ self.max_num_token,
107
+ self.model_runner.model_config.hidden_size,
108
+ ),
109
+ dtype=self.model_runner.dtype,
110
+ )
111
+ if self.require_mlp_tp_gather:
112
+ self.global_num_tokens_gpu = torch.zeros(
113
+ (self.dp_size,), dtype=torch.int32
114
+ )
115
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
116
+ (self.dp_size,), dtype=torch.int32
117
+ )
118
+ else:
119
+ assert self.require_attn_tp_gather
120
+ self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
121
+ self.global_num_tokens_for_logprob_gpu = torch.zeros(
122
+ (1,), dtype=torch.int32
123
+ )
93
124
  # Capture
94
125
  try:
95
126
  with model_capture_mode():
@@ -100,14 +131,24 @@ class EAGLEDraftExtendCudaGraphRunner:
100
131
  )
101
132
 
102
133
  def can_run(self, forward_batch: ForwardBatch):
103
- batch_size = forward_batch.seq_lens.numel()
134
+ if self.require_mlp_tp_gather:
135
+ cuda_graph_bs = (
136
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
137
+ if self.model_runner.spec_algorithm.is_eagle()
138
+ else sum(forward_batch.global_num_tokens_cpu)
139
+ )
140
+ else:
141
+ cuda_graph_bs = forward_batch.seq_lens.numel()
104
142
 
105
143
  is_bs_supported = (
106
- batch_size in self.graphs
144
+ cuda_graph_bs in self.graphs
107
145
  if self.disable_padding
108
- else batch_size <= self.max_bs
146
+ else cuda_graph_bs <= self.max_bs
109
147
  )
110
148
 
149
+ if self.require_mlp_sync:
150
+ is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
151
+
111
152
  return is_bs_supported
112
153
 
113
154
  def capture(self):
@@ -128,6 +169,53 @@ class EAGLEDraftExtendCudaGraphRunner:
128
169
  positions = self.positions[:num_tokens]
129
170
  hidden_states = self.hidden_states[:num_tokens]
130
171
 
172
+ if self.require_mlp_tp_gather:
173
+ self.global_num_tokens_gpu.copy_(
174
+ torch.tensor(
175
+ [
176
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
177
+ for i in range(self.dp_size)
178
+ ],
179
+ dtype=torch.int32,
180
+ device=self.input_ids.device,
181
+ )
182
+ )
183
+ self.global_num_tokens_for_logprob_gpu.copy_(
184
+ torch.tensor(
185
+ [
186
+ num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
187
+ for i in range(self.dp_size)
188
+ ],
189
+ dtype=torch.int32,
190
+ device=self.input_ids.device,
191
+ )
192
+ )
193
+ global_num_tokens = self.global_num_tokens_gpu
194
+ gathered_buffer = self.gathered_buffer[:num_tokens]
195
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
196
+ elif self.require_attn_tp_gather:
197
+ self.global_num_tokens_gpu.copy_(
198
+ torch.tensor(
199
+ [num_tokens],
200
+ dtype=torch.int32,
201
+ device=self.input_ids.device,
202
+ )
203
+ )
204
+ self.global_num_tokens_for_logprob_gpu.copy_(
205
+ torch.tensor(
206
+ [num_tokens],
207
+ dtype=torch.int32,
208
+ device=self.input_ids.device,
209
+ )
210
+ )
211
+ global_num_tokens = self.global_num_tokens_gpu
212
+ gathered_buffer = self.gathered_buffer[:num_tokens]
213
+ global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
214
+ else:
215
+ global_num_tokens = None
216
+ gathered_buffer = None
217
+ global_num_tokens_for_logprob = None
218
+
131
219
  spec_info = EagleDraftInput(
132
220
  hidden_states=hidden_states,
133
221
  accept_length=accept_length,
@@ -147,6 +235,9 @@ class EAGLEDraftExtendCudaGraphRunner:
147
235
  seq_lens_sum=seq_lens.sum().item(),
148
236
  return_logprob=False,
149
237
  positions=positions,
238
+ global_num_tokens_gpu=global_num_tokens,
239
+ global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
240
+ gathered_buffer=gathered_buffer,
150
241
  spec_algorithm=self.model_runner.spec_algorithm,
151
242
  spec_info=spec_info,
152
243
  capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +258,9 @@ class EAGLEDraftExtendCudaGraphRunner:
167
258
 
168
259
  # Run and capture
169
260
  def run_once():
261
+ # Clean intermediate result cache for DP attention
262
+ forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
263
+
170
264
  # Backup two fields, which will be modified in-place in `draft_forward`.
171
265
  output_cache_loc_backup = forward_batch.out_cache_loc
172
266
  hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,38 +297,57 @@ class EAGLEDraftExtendCudaGraphRunner:
203
297
  # in the batch, which will not be counted as num_seqs
204
298
  raw_bs = forward_batch.batch_size
205
299
  num_tokens = forward_batch.input_ids.shape[0]
300
+ if self.require_mlp_tp_gather:
301
+ total_batch_size = (
302
+ sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
303
+ if self.model_runner.spec_algorithm.is_eagle()
304
+ else sum(forward_batch.global_num_tokens_cpu)
305
+ )
306
+ index = bisect.bisect_left(self.capture_bs, total_batch_size)
307
+ else:
308
+ index = bisect.bisect_left(self.capture_bs, raw_bs)
206
309
 
207
- index = bisect.bisect_left(self.capture_bs, raw_bs)
208
310
  bs = self.capture_bs[index]
209
311
  if bs * self.num_tokens_per_bs != num_tokens:
210
- self.seq_lens.fill_(1)
211
- self.accept_length.fill_(1)
312
+ self.seq_lens.fill_(self.seq_len_fill_value)
212
313
  self.out_cache_loc.zero_()
314
+ self.accept_length.fill_(1)
315
+ self.extend_seq_lens.fill_(1)
213
316
 
214
317
  # Common inputs
215
318
  self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
216
319
  self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
217
- self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
320
+ if forward_batch.extend_seq_lens is not None:
321
+ self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
218
322
  self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
219
323
  self.positions[:num_tokens].copy_(forward_batch.positions)
220
324
  self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
221
- self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
325
+ if forward_batch.spec_info.accept_length is not None:
326
+ self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
222
327
  self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
223
328
 
329
+ if self.require_gathered_buffer:
330
+ self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
331
+ self.global_num_tokens_for_logprob_gpu.copy_(
332
+ forward_batch.global_num_tokens_for_logprob_gpu
333
+ )
334
+ forward_batch.gathered_buffer = self.gathered_buffer
335
+
224
336
  if forward_batch.seq_lens_cpu is not None:
225
337
  if bs != raw_bs:
226
- self.seq_lens_cpu.fill_(1)
338
+ self.seq_lens_cpu.fill_(self.seq_len_fill_value)
227
339
  self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
228
340
 
229
341
  if bs != raw_bs:
342
+ forward_batch.spec_info.positions = self.positions[:num_tokens]
230
343
  forward_batch.spec_info.accept_length = self.accept_length[:bs]
231
- forward_batch.spec_info.positions = None
232
344
 
233
345
  self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
234
346
  bs=bs,
235
347
  req_pool_indices=self.req_pool_indices,
236
348
  seq_lens=self.seq_lens,
237
- seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
349
+ seq_lens_sum=forward_batch.seq_lens_sum
350
+ + (bs - raw_bs) * self.seq_len_fill_value,
238
351
  encoder_lens=None,
239
352
  forward_mode=ForwardMode.DRAFT_EXTEND,
240
353
  spec_info=forward_batch.spec_info,
@@ -21,20 +21,22 @@ from sglang.srt.managers.schedule_batch import (
21
21
  get_last_loc,
22
22
  global_server_args_dict,
23
23
  )
24
- from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
24
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
25
25
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
26
26
  from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
27
27
 
28
+ logger = logging.getLogger(__name__)
29
+
28
30
  if is_cuda():
29
31
  from sgl_kernel import (
32
+ fast_topk,
30
33
  top_k_renorm_prob,
31
34
  top_p_renorm_prob,
32
35
  tree_speculative_sampling_target_only,
33
36
  verify_tree_greedy,
34
37
  )
35
- from sgl_kernel.top_k import fast_topk
36
38
  elif is_hip():
37
- from sgl_kernel import verify_tree_greedy
39
+ from sgl_kernel import fast_topk, verify_tree_greedy
38
40
 
39
41
 
40
42
  logger = logging.getLogger(__name__)
@@ -69,6 +71,8 @@ class EagleDraftInput:
69
71
  kv_indices: torch.Tensor = None
70
72
 
71
73
  def prepare_for_extend(self, batch: ScheduleBatch):
74
+ if batch.forward_mode.is_idle():
75
+ return
72
76
  # Prefill only generate 1 token.
73
77
  assert len(self.verified_id) == len(batch.seq_lens)
74
78
 
@@ -80,6 +84,25 @@ class EagleDraftInput:
80
84
  )
81
85
  pt += extend_len
82
86
 
87
+ @classmethod
88
+ def create_idle_input(
89
+ cls,
90
+ device: torch.device,
91
+ hidden_size: int,
92
+ dtype: torch.dtype,
93
+ topk: int,
94
+ capture_hidden_mode: CaptureHiddenMode,
95
+ ):
96
+ return cls(
97
+ verified_id=None,
98
+ hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
99
+ topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
100
+ topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
101
+ capture_hidden_mode=capture_hidden_mode,
102
+ accept_length=torch.empty((0,), device=device, dtype=torch.int32),
103
+ accept_length_cpu=[],
104
+ )
105
+
83
106
  def prepare_extend_after_decode(
84
107
  self,
85
108
  batch: ScheduleBatch,
@@ -193,7 +216,35 @@ class EagleVerifyInput:
193
216
  seq_lens_cpu: torch.Tensor
194
217
  grammar: BaseGrammarObject = None
195
218
 
219
+ @classmethod
220
+ def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
221
+ return cls(
222
+ draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
223
+ custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
224
+ positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
225
+ retrive_index=torch.full(
226
+ (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
227
+ ),
228
+ retrive_next_token=torch.full(
229
+ (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
230
+ ),
231
+ retrive_next_sibling=torch.full(
232
+ (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
233
+ ),
234
+ retrive_cum_len=None,
235
+ topk=topk,
236
+ draft_token_num=num_verify_tokens,
237
+ spec_steps=spec_steps,
238
+ capture_hidden_mode=CaptureHiddenMode.FULL,
239
+ seq_lens_sum=0,
240
+ seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
241
+ )
242
+
196
243
  def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
244
+
245
+ if batch.forward_mode.is_idle():
246
+ return
247
+
197
248
  batch.input_ids = self.draft_token
198
249
 
199
250
  if page_size == 1:
@@ -265,7 +316,7 @@ class EagleVerifyInput:
265
316
  self,
266
317
  batch: ScheduleBatch,
267
318
  logits_output: torch.Tensor,
268
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
319
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
269
320
  page_size: int,
270
321
  vocab_mask: Optional[torch.Tensor] = None, # For grammar
271
322
  ) -> torch.Tensor:
@@ -279,6 +330,26 @@ class EagleVerifyInput:
279
330
  tokens. I.e., logits_output.next_token_logits only contains
280
331
  accepted token logits.
281
332
  """
333
+ if batch.forward_mode.is_idle():
334
+ return EagleVerifyOutput(
335
+ draft_input=EagleDraftInput.create_idle_input(
336
+ device=batch.device,
337
+ hidden_size=batch.model_config.hidden_size,
338
+ dtype=batch.model_config.dtype,
339
+ topk=self.topk,
340
+ capture_hidden_mode=CaptureHiddenMode.LAST,
341
+ ),
342
+ logits_output=logits_output,
343
+ verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
344
+ accept_length_per_req_cpu=[],
345
+ accepted_indices=torch.full(
346
+ (0, self.spec_steps + 1),
347
+ -1,
348
+ dtype=torch.int32,
349
+ device=batch.device,
350
+ ),
351
+ )
352
+
282
353
  bs = self.retrive_index.shape[0]
283
354
  candidates = self.draft_token.reshape(bs, self.draft_token_num)
284
355
  sampling_info = batch.sampling_info
@@ -992,10 +1063,11 @@ def select_top_k_tokens(
992
1063
  topk_index = topk_index.reshape(-1, topk**2)
993
1064
  input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
994
1065
 
995
- selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
996
- 0, hidden_states.shape[0], step=topk, device="cuda"
997
- ).repeat_interleave(topk)
998
- hidden_states = hidden_states[selected_input_index, :]
1066
+ if hidden_states.shape[0] > 0:
1067
+ selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
1068
+ 0, hidden_states.shape[0], step=topk, device="cuda"
1069
+ ).repeat_interleave(topk)
1070
+ hidden_states = hidden_states[selected_input_index, :]
999
1071
 
1000
1072
  tree_info = (
1001
1073
  expand_scores, # shape: (b, topk, topk)