sglang 0.4.7.post1__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/constants.py +3 -0
  5. sglang/srt/conversation.py +13 -3
  6. sglang/srt/custom_op.py +5 -1
  7. sglang/srt/disaggregation/decode.py +22 -28
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  9. sglang/srt/disaggregation/mini_lb.py +34 -4
  10. sglang/srt/disaggregation/mooncake/conn.py +12 -16
  11. sglang/srt/disaggregation/prefill.py +17 -13
  12. sglang/srt/disaggregation/utils.py +46 -18
  13. sglang/srt/distributed/parallel_state.py +12 -4
  14. sglang/srt/entrypoints/engine.py +22 -28
  15. sglang/srt/entrypoints/http_server.py +149 -79
  16. sglang/srt/entrypoints/http_server_engine.py +0 -3
  17. sglang/srt/entrypoints/openai/__init__.py +0 -0
  18. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +67 -29
  19. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  20. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  21. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  22. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  23. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  24. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  25. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  26. sglang/srt/entrypoints/openai/utils.py +72 -0
  27. sglang/srt/function_call/base_format_detector.py +7 -4
  28. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  29. sglang/srt/function_call/ebnf_composer.py +64 -10
  30. sglang/srt/function_call/function_call_parser.py +6 -6
  31. sglang/srt/function_call/llama32_detector.py +1 -1
  32. sglang/srt/function_call/mistral_detector.py +1 -1
  33. sglang/srt/function_call/pythonic_detector.py +1 -1
  34. sglang/srt/function_call/qwen25_detector.py +1 -1
  35. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  36. sglang/srt/layers/activation.py +21 -3
  37. sglang/srt/layers/attention/aiter_backend.py +5 -2
  38. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  39. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  40. sglang/srt/layers/attention/flashattention_backend.py +19 -9
  41. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  42. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  43. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  44. sglang/srt/layers/attention/tbo_backend.py +3 -3
  45. sglang/srt/layers/attention/triton_backend.py +19 -11
  46. sglang/srt/layers/communicator.py +5 -5
  47. sglang/srt/layers/dp_attention.py +11 -2
  48. sglang/srt/layers/layernorm.py +29 -2
  49. sglang/srt/layers/logits_processor.py +2 -2
  50. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  51. sglang/srt/layers/moe/ep_moe/layer.py +207 -1
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +6 -0
  54. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  55. sglang/srt/layers/moe/topk.py +91 -4
  56. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  57. sglang/srt/layers/quantization/fp8.py +25 -17
  58. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  59. sglang/srt/layers/quantization/utils.py +5 -2
  60. sglang/srt/layers/rotary_embedding.py +42 -2
  61. sglang/srt/layers/sampler.py +1 -1
  62. sglang/srt/lora/lora_manager.py +173 -74
  63. sglang/srt/lora/mem_pool.py +49 -45
  64. sglang/srt/lora/utils.py +1 -1
  65. sglang/srt/managers/cache_controller.py +33 -15
  66. sglang/srt/managers/io_struct.py +9 -12
  67. sglang/srt/managers/schedule_batch.py +40 -31
  68. sglang/srt/managers/schedule_policy.py +70 -56
  69. sglang/srt/managers/scheduler.py +147 -62
  70. sglang/srt/managers/template_manager.py +226 -0
  71. sglang/srt/managers/tokenizer_manager.py +11 -8
  72. sglang/srt/managers/tp_worker.py +12 -2
  73. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  74. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  75. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  76. sglang/srt/mem_cache/chunk_cache.py +11 -16
  77. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  78. sglang/srt/mem_cache/memory_pool.py +118 -114
  79. sglang/srt/mem_cache/radix_cache.py +20 -16
  80. sglang/srt/model_executor/cuda_graph_runner.py +76 -45
  81. sglang/srt/model_executor/forward_batch_info.py +18 -5
  82. sglang/srt/model_executor/model_runner.py +22 -6
  83. sglang/srt/model_loader/loader.py +8 -1
  84. sglang/srt/model_loader/weight_utils.py +11 -2
  85. sglang/srt/models/deepseek_nextn.py +29 -27
  86. sglang/srt/models/deepseek_v2.py +108 -26
  87. sglang/srt/models/glm4.py +312 -0
  88. sglang/srt/models/mimo_mtp.py +2 -18
  89. sglang/srt/reasoning_parser.py +21 -11
  90. sglang/srt/server_args.py +36 -8
  91. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  92. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  93. sglang/srt/speculative/eagle_utils.py +80 -8
  94. sglang/srt/speculative/eagle_worker.py +124 -41
  95. sglang/srt/torch_memory_saver_adapter.py +19 -15
  96. sglang/srt/utils.py +177 -11
  97. sglang/test/test_block_fp8_ep.py +1 -0
  98. sglang/test/test_utils.py +1 -0
  99. sglang/version.py +1 -1
  100. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/METADATA +4 -10
  101. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/RECORD +104 -93
  102. sglang/srt/entrypoints/verl_engine.py +0 -179
  103. sglang/srt/openai_api/adapter.py +0 -2148
  104. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  105. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  106. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0

sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (+131 -10)

@@ -20,6 +20,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -38,6 +44,12 @@ class EAGLEDraftCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
+        self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -53,7 +65,9 @@ class EAGLEDraftCudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
-        self.model_runner.draft_attn_backend.init_cuda_graph_state(self.max_num_token)
+        self.model_runner.draft_attn_backend.init_cuda_graph_state(
+            self.max_bs, self.max_num_token
+        )
         self.seq_len_fill_value = self.model_runner.draft_attn_backend.attn_backends[
             0
         ].get_cuda_graph_seq_len_fill_value()
@@ -78,10 +92,32 @@ class EAGLEDraftCudaGraphRunner:
             self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
             self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
             self.hidden_states = torch.zeros(
-                (self.max_num_token, self.model_runner.model_config.hidden_size),
+                (self.max_bs, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
             )
 
+            if self.require_gathered_buffer:
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                if self.require_mlp_tp_gather:
+                    self.global_num_tokens_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                else:
+                    assert self.require_attn_tp_gather
+                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (1,), dtype=torch.int32
+                    )
+
         # Capture
         try:
             with model_capture_mode():
@@ -92,11 +128,24 @@ class EAGLEDraftCudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.batch_size
+
         is_bs_supported = (
-            forward_batch.batch_size in self.graphs
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else forward_batch.batch_size <= self.max_bs
+            else cuda_graph_bs <= self.max_bs
         )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
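Note: when MLP TP gathering is required, can_run now derives an effective CUDA-graph batch size from the summed per-rank token counts before checking it against the captured sizes. A minimal standalone sketch of that check (illustrative names and values, not sglang's actual ForwardBatch or server_args API):

    from typing import List

    def cuda_graph_batch_size(
        global_num_tokens_cpu: List[int], num_tokens_per_bs: int, is_eagle: bool
    ) -> int:
        # With EAGLE, each request contributes num_tokens_per_bs draft tokens,
        # so the summed token count is folded back into a batch size.
        total = sum(global_num_tokens_cpu)
        return total // num_tokens_per_bs if is_eagle else total

    def can_run(cuda_graph_bs: int, capture_bs: List[int], disable_padding: bool) -> bool:
        # Without padding the exact size must have been captured; with padding,
        # anything up to the largest captured graph can be padded and replayed.
        if disable_padding:
            return cuda_graph_bs in capture_bs
        return cuda_graph_bs <= max(capture_bs)

    bs = cuda_graph_batch_size([6, 6], num_tokens_per_bs=4, is_eagle=True)
    print(bs, can_run(bs, [1, 2, 4, 8, 16], disable_padding=False))  # 3 True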
@@ -116,8 +165,58 @@ class EAGLEDraftCudaGraphRunner:
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
 
+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
-            topk_p=topk_p, topk_index=topk_index, hidden_states=hidden_states
+            topk_p=topk_p,
+            topk_index=topk_index,
+            hidden_states=hidden_states,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
         )
 
         # Forward batch
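The list comprehension copied into global_num_tokens_gpu splits the padded token count as evenly as possible across the data-parallel ranks. A small self-check of that formula (plain Python, no sglang imports):

    def split_tokens(num_tokens: int, dp_size: int) -> list:
        # Rank i gets num_tokens // dp_size tokens, plus one extra while
        # i < num_tokens % dp_size, so the shares differ by at most one
        # and always sum back to num_tokens.
        return [
            num_tokens // dp_size + (i < (num_tokens % dp_size))
            for i in range(dp_size)
        ]

    assert split_tokens(10, 4) == [3, 3, 2, 2]
    assert all(sum(split_tokens(n, d)) == n for n in range(32) for d in range(1, 8))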
@@ -133,11 +232,14 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
                 spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
             ),
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
         )
 
         # Attention backend
@@ -147,6 +249,9 @@ class EAGLEDraftCudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -184,12 +289,19 @@ class EAGLEDraftCudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
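Padding still picks the target graph with bisect.bisect_left over the sorted capture sizes, i.e. the smallest captured batch size that can hold the incoming batch, while the padded entries are now filled with self.seq_len_fill_value instead of a hard-coded 1. For illustration (assuming an ascending capture_bs, as the runner maintains):

    import bisect

    capture_bs = [1, 2, 4, 8, 16]  # batch sizes for which graphs were captured

    for raw_bs in (1, 3, 8, 11):
        # Smallest captured batch size >= raw_bs becomes the replay size.
        bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
        print(raw_bs, "->", bs)  # 1 -> 1, 3 -> 4, 8 -> 8, 11 -> 16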
@@ -204,6 +316,13 @@ class EAGLEDraftCudaGraphRunner:
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
 
+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         # Attention backend
         if bs != raw_bs:
             forward_batch.batch_size = bs
@@ -212,14 +331,16 @@ class EAGLEDraftCudaGraphRunner:
             forward_batch.positions = self.positions[:num_tokens]
 
         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+        if forward_batch.seq_lens_cpu is not None:
+            if bs != raw_bs:
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
 
         # Replay
         self.graphs[bs].replay()

sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py (+125 -12)

@@ -21,6 +21,12 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_sync,
+    require_mlp_tp_gather,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,6 +41,10 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -51,7 +61,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
 
         self.eagle_worker.draft_extend_attn_backend.init_cuda_graph_state(
-            self.max_num_token
+            self.max_bs, self.max_num_token
         )
         self.seq_len_fill_value = (
             self.eagle_worker.draft_extend_attn_backend.get_cuda_graph_seq_len_fill_value()
@@ -90,6 +100,27 @@ class EAGLEDraftExtendCudaGraphRunner:
                 (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
             )
 
+            if self.require_gathered_buffer:
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                if self.require_mlp_tp_gather:
+                    self.global_num_tokens_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (self.dp_size,), dtype=torch.int32
+                    )
+                else:
+                    assert self.require_attn_tp_gather
+                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                        (1,), dtype=torch.int32
+                    )
         # Capture
         try:
             with model_capture_mode():
@@ -100,14 +131,24 @@ class EAGLEDraftExtendCudaGraphRunner:
             )
 
     def can_run(self, forward_batch: ForwardBatch):
-        batch_size = forward_batch.seq_lens.numel()
+        if self.require_mlp_tp_gather:
+            cuda_graph_bs = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+        else:
+            cuda_graph_bs = forward_batch.seq_lens.numel()
 
         is_bs_supported = (
-            batch_size in self.graphs
+            cuda_graph_bs in self.graphs
             if self.disable_padding
-            else batch_size <= self.max_bs
+            else cuda_graph_bs <= self.max_bs
         )
 
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
@@ -128,6 +169,53 @@ class EAGLEDraftExtendCudaGraphRunner:
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
 
+        if self.require_mlp_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        else:
+            global_num_tokens = None
+            gathered_buffer = None
+            global_num_tokens_for_logprob = None
+
         spec_info = EagleDraftInput(
             hidden_states=hidden_states,
             accept_length=accept_length,
@@ -147,6 +235,9 @@ class EAGLEDraftExtendCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            global_num_tokens_gpu=global_num_tokens,
+            global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob,
+            gathered_buffer=gathered_buffer,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=CaptureHiddenMode.LAST,
@@ -167,6 +258,9 @@ class EAGLEDraftExtendCudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             # Backup two fields, which will be modified in-place in `draft_forward`.
             output_cache_loc_backup = forward_batch.out_cache_loc
             hidden_states_backup = forward_batch.spec_info.hidden_states
@@ -203,38 +297,57 @@ class EAGLEDraftExtendCudaGraphRunner:
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
        num_tokens = forward_batch.input_ids.shape[0]
+        if self.require_mlp_tp_gather:
+            total_batch_size = (
+                sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
+                if self.model_runner.spec_algorithm.is_eagle()
+                else sum(forward_batch.global_num_tokens_cpu)
+            )
+            index = bisect.bisect_left(self.capture_bs, total_batch_size)
+        else:
+            index = bisect.bisect_left(self.capture_bs, raw_bs)
 
-        index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)
+            self.extend_seq_lens.fill_(1)
 
         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-        self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
+        if forward_batch.extend_seq_lens is not None:
+            self.extend_seq_lens[:raw_bs].copy_(forward_batch.extend_seq_lens)
         self.out_cache_loc[:num_tokens].copy_(forward_batch.out_cache_loc)
         self.positions[:num_tokens].copy_(forward_batch.positions)
         self.hidden_states[:num_tokens].copy_(forward_batch.spec_info.hidden_states)
-        self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
+        if forward_batch.spec_info.accept_length is not None:
+            self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
 
+        if self.require_gathered_buffer:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                forward_batch.global_num_tokens_for_logprob_gpu
+            )
+            forward_batch.gathered_buffer = self.gathered_buffer
+
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if bs != raw_bs:
+            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
-            forward_batch.spec_info.positions = None
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
            encoder_lens=None,
            forward_mode=ForwardMode.DRAFT_EXTEND,
            spec_info=forward_batch.spec_info,
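Because padded entries of seq_lens now hold seq_len_fill_value rather than 1, the precomputed seq_lens_sum handed to the attention backend is corrected by (bs - raw_bs) * seq_len_fill_value. A quick check of that arithmetic with illustrative values:

    import torch

    seq_len_fill_value = 1  # fill value used for padded slots (illustrative)
    raw_bs, bs = 3, 4       # real requests vs. padded cuda-graph batch size

    seq_lens = torch.tensor([7, 5, 9])
    padded = torch.full((bs,), seq_len_fill_value)
    padded[:raw_bs] = seq_lens

    # Correcting the precomputed sum avoids re-reducing the padded tensor.
    assert padded.sum().item() == seq_lens.sum().item() + (bs - raw_bs) * seq_len_fill_value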

sglang/srt/speculative/eagle_utils.py (+80 -8)

@@ -21,20 +21,22 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
+logger = logging.getLogger(__name__)
+
 if is_cuda():
     from sgl_kernel import (
+        fast_topk,
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
-    from sgl_kernel.top_k import fast_topk
 elif is_hip():
-    from sgl_kernel import verify_tree_greedy
+    from sgl_kernel import fast_topk, verify_tree_greedy
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +71,8 @@ class EagleDraftInput:
     kv_indices: torch.Tensor = None
 
     def prepare_for_extend(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_idle():
+            return
         # Prefill only generate 1 token.
         assert len(self.verified_id) == len(batch.seq_lens)
 
@@ -80,6 +84,25 @@ class EagleDraftInput:
             )
             pt += extend_len
 
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        dtype: torch.dtype,
+        topk: int,
+        capture_hidden_mode: CaptureHiddenMode,
+    ):
+        return cls(
+            verified_id=None,
+            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
+            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
+        )
+
     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
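create_idle_input returns zero-row tensors with the correct trailing dimensions instead of None, so shape-dependent code keeps working when a data-parallel rank has no requests. A tiny CPU-only illustration of the pattern (sizes are placeholders, not sglang defaults):

    import torch

    hidden_size, topk = 4096, 8  # placeholder sizes
    hidden_states = torch.empty((0, hidden_size), dtype=torch.float16)
    topk_p = torch.empty((0, topk), dtype=torch.float32)
    topk_index = torch.empty((0, topk), dtype=torch.int64)

    # Zero-row tensors still carry dtype/shape metadata, so concatenation,
    # slicing, and length checks behave like a real batch of size 0.
    print(hidden_states.shape[0])             # 0
    print(torch.cat([topk_p, topk_p]).shape)  # torch.Size([0, 8])
    print(topk_index[:0].shape)               # torch.Size([0, 8])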
@@ -193,7 +216,35 @@ class EagleVerifyInput:
     seq_lens_cpu: torch.Tensor
     grammar: BaseGrammarObject = None
 
+    @classmethod
+    def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
+        return cls(
+            draft_token=torch.empty((0,), dtype=torch.long, device="cuda"),
+            custom_mask=torch.full((0,), True, dtype=torch.bool, device="cuda"),
+            positions=torch.empty((0,), dtype=torch.int64, device="cuda"),
+            retrive_index=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_token=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_next_sibling=torch.full(
+                (0, num_verify_tokens), -1, dtype=torch.long, device="cuda"
+            ),
+            retrive_cum_len=None,
+            topk=topk,
+            draft_token_num=num_verify_tokens,
+            spec_steps=spec_steps,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+            seq_lens_sum=0,
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+        )
+
     def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
+
+        if batch.forward_mode.is_idle():
+            return
+
         batch.input_ids = self.draft_token
 
         if page_size == 1:
@@ -265,7 +316,7 @@ class EagleVerifyInput:
         self,
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
     ) -> torch.Tensor:
@@ -279,6 +330,26 @@ class EagleVerifyInput:
         tokens. I.e., logits_output.next_token_logits only contains
         accepted token logits.
         """
+        if batch.forward_mode.is_idle():
+            return EagleVerifyOutput(
+                draft_input=EagleDraftInput.create_idle_input(
+                    device=batch.device,
+                    hidden_size=batch.model_config.hidden_size,
+                    dtype=batch.model_config.dtype,
+                    topk=self.topk,
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                ),
+                logits_output=logits_output,
+                verified_id=torch.empty(0, dtype=torch.long, device=batch.device),
+                accept_length_per_req_cpu=[],
+                accepted_indices=torch.full(
+                    (0, self.spec_steps + 1),
+                    -1,
+                    dtype=torch.int32,
+                    device=batch.device,
+                ),
+            )
+
         bs = self.retrive_index.shape[0]
         candidates = self.draft_token.reshape(bs, self.draft_token_num)
         sampling_info = batch.sampling_info
@@ -992,10 +1063,11 @@ def select_top_k_tokens(
         topk_index = topk_index.reshape(-1, topk**2)
         input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
 
-        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
-            0, hidden_states.shape[0], step=topk, device="cuda"
-        ).repeat_interleave(topk)
-        hidden_states = hidden_states[selected_input_index, :]
+        if hidden_states.shape[0] > 0:
+            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
+                0, hidden_states.shape[0], step=topk, device="cuda"
+            ).repeat_interleave(topk)
+            hidden_states = hidden_states[selected_input_index, :]
 
         tree_info = (
             expand_scores,  # shape: (b, topk, topk)